-
[cs231n] 3강 손실 함수와 최적화 (2/4, 정규화 (regularization)와 소프트맥스 (softmax))AI 2021. 1. 11. 18:25
그건 여기 써 있는 것 때문인데요. 데이타에 관해 손실만 썼는데, 우리의 학습 데이터에 맞는 분류기에 대한 W를 찾아야 한다고 얘기했죠. 하지만 실제로, 학습 데이타에 맞추는 것에는 그렇게 관심이 없습니다. 머신러닝의 전체적인 요점은 학습 데이타를 사용해서 어떤 분류기를 찾는 건데, 그리고 그걸 테스트 데이타에 적용하는 거죠. 그래서 우리는 훈련 데이타 성능에 관심이 없고 테스트 데이타에 대한 분류기 성능이 중요합니다. 결과적으로, 우리가 분류기에게 얘기하는 것은, 훈련 데이타에 핏 (fit)하라는 건데, 어떤 경우 가끔 이상한 상황으로 우리를 몰고 갑니다. 분류기가 비직관적인 행동하는 거죠. 그래서 구체적이고 표준적인 예는 선형 분류기가 아닌, 약간 더 일반적인 머신러닝 개념에 대해 얘기할 겁니다.
이런 파란 점들의 데이타 셋이 있다면, 우리는 이 파란 점들인 학습 데이타에 맞는 어떤 곡선을 맞춰 (fit)볼 겁니다. 우리가 분류기에게 하라고 얘기하는 것은 학습 데이타에 맞추라는 것이고 그럼 모든 학습 데이타 점에 대해서 완벽하게 분류하게 하기 위해서 매우 구불구불한 곡선을 가집니다. 이건 나쁜 거죠.. 우리는 이 성능에 관심이 없으며, 우리는 테스트 데이타에 대한 성능에 신경을 씁니다.
그래서 만약 같은 트렌드를 따르는 새로운 데이타가 들어오면, 구불구불한 파란선이 완전히 틀리게 되는거죠.
그럼 사실, 우리는 아마도 분류기가 학습 데이타에 완벽히 맞는 이 매우 복잡하고 구불구불한 선 보다는 초록 직선을 예측하는 걸 더 좋아할텐데요. 이것은 머신러닝에서 핵심 기본 문제입니다.
우리가 보통 이걸 해결하는 방법은 정규화 컨셉입니다. 우리는 손실함수에 추가적인 항을 데이타 손실에 더해, 우리의 분류기에게 학습 데이타에 맞추라고 얘기하는 거고, 전형적으로 정규화 항이라고 불리는 다른 항을 손실함수에 더합니다. 그건 모델이 어느정도 단순한 W를 고르도록 유도하는데요. 단순함의 개념은 작업과 모델에 따라 다릅니다.
오캄의 면도날이라는 이 과학적 아이디어는 더 넓게는 근본적 아이디어의 과학적 발견에 있는데요. 만약 여러분의 관찰을 설명할 수 있는 경쟁하는 가설들이 있다면, 일반적으로 더 간단한 것을 선호해야 한다는 거죠. 왜냐면 그것이 미래의 관찰을 일반화할 수 있는 설명이라는 겁니다. 그리고 우리가 이 직관을 머신러닝에 적용하면, 전형적으로 어떤 명시적 정규화 페널티를 거칩니다, 종종 R이라고 적는데요. 그래서 여러분의 표준 손실 함수는 이 2개의 항을 가지는데, 데이타 손실과 정규화 손실이죠 . 그리고 둘 사이에 트래이드오프 (trade-off)하는 람다인 어떤 하이퍼파라미터가 있구요. 우리는 하이퍼파라미터와 교차 검증에 대해 지난 시간에 얘기했는데, 이 정규화 하이퍼파라미터 람다는 여러분이 이 모델을 실제로 훈련시킬 때 튜닝해야 하는 가장 중요한 것 중 하나일 겁니다.
사실 실제로 사용되는 여러 종류의 정규화가 있어요. 가장 흔히 쓰는 건 아마 L2 정규화 혹은 가중치 감쇠입니다. L2 정규화는 이 가중치 벡터 W에 대한 유클리드 놈 (norm)인데, 때로는 제곱 놈 혹은 가끔 반 제곱 놈을 쓰기도 합니다. 왜냐면, 이게 미분을 없애 줘서 더 낫거든요. L2 정규화 아이디어는 여러분이 단지 이 가중치 벡터의 유클리드 놈에 패널티를 주는 겁니다. 가끔은 L1 정규화도 볼텐데, 가중치 벡터의 L1 놈에게 패널티를 주는 거죠. L1 정규화는 좋은 특징들이 있는데, 예를 들면 매트릭스 W의 희소성 (sparsity)를 증가시킵니다. 볼만한 다른 것은 엘라스틱 넷 정규화인데, L1과 L2를 조합한 거죠. 때로는 최대 놈 (max norm) 정규화도 볼 수 있는데, L1, L2 놈이 아닌 최대 놈에 패널티를 주는 겁니다. 이런 종류의 정규화는 딥러닝에서만 보는 게 아니고, 머신 러닝의 많은 영역에서 보는 거고 더 넓게는 최적화에서도 보입니다. 나중에 딥러닝에 더 특화된 정규화 몇 개를 볼 건데요. 예를 들면 드랍아웃 (dropout)이죠. 배치 정규화 (batch normalization), 확률적 깊이 (stochastic depth) 등도 있죠. 이런 것들은 최근 몇 년간 미쳤는데, 전체적으로 정규화 아이디어는 어러분의 모델에 어떤 식으로는 명시적으로 학습데이타에 맞추려고 (fit) 하기 보다는 모델의 복잡성을 불리하게 하는 모든 것이 될 수 있죠.
L2 정규화는 모델의 복잡도를 어떻게 측정하죠? 고맙게도 여기 예를 살펴보죠. 우리 학습 예제 x가 있고, 고려하고 있는 W 2개가 있죠. x는 4개의 1로 된 벡터고, 우리는 2개의 w를 고려하고 있는데, 하나는 [1, 0, 0, 0]이고 다른 하나는 [0.25, 025, 0.25, 0.25]죠. 우리가 선형 분류를 할 때, 우리는 x와 w간의 내적을 취하죠. 선형 분류 관점에서, 이 2개의 w는 같죠. 왜냐면 같은 결과를 내니까요. 질문은 이 두개 예를 x와 내적하면, 어떤 것을 L2 회기가 선호할까요? L2 회귀는 w2를 선호하죠. 왜냐면 더 작은 놈이니까요. 답은 L2 회귀는 분류기의 복잡도를 이런 비교적 조잡한 방식으로 측정합니다. L2 회귀는 그 영향을 x의 모든 값에 대해 퍼트리는 것을 선호한다고 할 수 있죠. 아마도 만약 변화하는 x를 가져온다면, 이게 더 견고할 겁니다. 우리의 결정은 넓게 펼쳐지고, x벡터의 하나의 어떤 원소에 의존하기 보다는 전체 x 벡터에 의존하죠. 그런데, L1은 그 반대의 해석을 할 수 있습니다. 만약 우리가 L1 정규화를 쓴다면, 우리는 사실 w1을 w2보다 선호하는 거죠. 왜냐면, L1 정규화는 다른 복잡도 개념을 가져서, 아마 그 모델이 덜 복잡하다고 말하고, 아마도 우리는 모델의 복잡도를 가중치 벡터에서 0의 개수로 측정한다고 말하는 거죠. L2가 어떻게 복잡도를 측정하나요? 이건 문제에 따라 다른데, 여러분은 특정 셋업에 대해 생각하고, 특정 모델과 데이타에 대해 생각해야 합니다. 이 작업에서 복잡도가 어떻게 측정되어야 한다고 생각하나요?
만약 여러분이 하드코어 베이지안 (baysian)이라면, L2 정규화 사용을 파라미터 벡터에 대한 가우시안 전제 (Gaussian prior) 하의 맵 추론( MAP inference)이라고 해석할 수 있죠. 이게 저의 멀티클래스 SVM 손실에 대한 긴 딥 다이브 (deep dive)였습니다.
우리는 멀티클래스 SVM손실을 보았는데요. 사이드 노트 (side note)로, 이건 여러 클래스에 대한 SVM 손실의 확장 혹은 일반화구요. 사실 몇 가지 여러분이 책에서 볼 수 있는 다른 공식들이 있죠 . 제 직관으론 실제에선 비슷하게 동작합니다. 적어도 딥러닝에서는요.
물론 생각할 수 있는 다양한 손실 함수들이 있습니다. 멀티클래스 SVM외에 딥러닝에서 많이 쓰는 것은, 다항식 로지스틱 회귀 (multinomial logistic regression)죠. 혹은 소프트맥스 (softmax) 손실이죠. 딥러닝 문맥에선 이게 사실 좀 더 흔하지만, 몇가지 이유로 이걸 두번째로 소개하기로 했어요. 멀티클래스 SVM 손실에서는 이 점수에 대해 해석을 하지 않았습니다. 우리가 분류를 할 때, 우리의 모델 f는 클래스에 대한 점수들인 10개의 숫자를 내 뱉고, 멀티클래스 SVM에선 이 점수에 대해선 별로 해석을 하지 않았죠. 우린 그냥 참인 점수를 원한다고 했고, 맞는 클래스의 점수가 틀린 클래스보다 커야 한다고 했습니다. 그 이상으론 이 점수들이 뭘 의미하는진 얘기하지 않았구요.
그러나 다항식 로지스틱 회귀 손실 함수에선, 이 점수들에 추가적인 의미를 부여합니다.
특히, 우리의 클래스에 대해 확률 분포를 계산하기 위해서 이 점수들을 사용할 겁니다.
우리는 소위 소프트맥스 (softmax) 함수를 쓸 거고, 우리의 모든 점수를 얻을 건데요. 우리는 그걸 지수화해서 그럼 양수로 만듭니다. 이 지수들의 합으로 그것들을 다시 정규화 합니다. 이 소프트맥스 함수로 우리가 점수들을 통과시키면, 우리는 이 확률 분포를 얻게 됩니다. 여기서 우리는 클래스들에 대한 확률을 얻는데, 확률은 0부터 1사이의 수입니다. 그리고 모든 클래스에 걸친 확률의 합은 1이죠.
이건 우리의 점수가 의미하는 계산된 확률 분포인데, 우리는 이걸 타겟 혹은 진짜 확률 분포 비교하고 싶습니다. 만약 우리가 어떤 것이 고양이라는 걸 안다면, 목표 확률 분포는 모든 확률 질량을 고양이에 놓을 겁니다. 그럼 우리는 고양이의 확률이 1이라고 얻을 겁니다. 다른 클래스들은 모두 0이구요. 우리가 이제 하고 싶은 건 우리의 계산된 확률 분포가 이 소프트맥스 함수로부터 나오게 하는 겁니다. 모든 질량을 가지는 이 맞는 클래스에 대한 목표 확률 분포가 나오도록요. 이 방정식을 여러 방면으로 할 수 있는데, 타겟과 계산된 확률 분포 사이의 KL 발산 (KL divergence)으로 풀 수도 있고, 또 최대 우도 추정으로 할 수도 있죠. 마지막엔, 우리가 원하는건 이 참 클래스에 대한 확률이 높아서 1에 가까워지는 것입니다. 그럼 우리의 손실은 참인 클래스의 확률의 음수 로그겠죠.
기억할 건 우리는 확률이 1에 가깝길 원한다는 거죠. 로그는 계속 증가하는 함수인데, 수학적으로는 로 (raw) 확률을 최대화하는 것보다 로그를 최대화하는 것이 쉽다고 알려졌죠. 그러니 계속 로그를 쓰죠. 로그는 계속 증가하고, 만약 맞는 클래스의 P의 로그 값을 최대화화면, 즉, 그 말은 그것이 높길 바란다는 거고, 그러나 손실함수는 나쁜 걸 측정하는 거지 좋은 걸 측정하는 게 아니죠. 그래서 우리는 -1을 붙여서 바른 방향으로 가도록 합니다. 우리의 손실 함수 SVM은 참인 클래스 확률의 마이너스 로그입니다.
이건 그 요약이죠. 우리는 점수를 취해서 소프트맥스를 거치고, 우리의 손실은 이 참인 클래스의 확률의 마이너스 로그죠.
구체적인 예에서 어떻게 생겼는지 보기 위해, 3개의 예제와 함께 다시 아름다운 고양이로 가보죠. 우리는 선형 분류기로부터 나온 이 3개의 점수가 있죠. 이 점수들은 SVM 손실에서와 같습니다. 그러나 이제 이 점수들을 그냥 손실 함수에 넣지 말구요.
그것들을 모두 지수화해서 모두 양수로 만듭니다.
이걸 정규화해서 모두 더하면 1이 되도록 합니다.
그럼 우리 손실은 참인 클래스의 음수 로그가 되죠. 이게 소프트맥스 손실입니다. 또는 다항 로지스틱 회귀라고 합니다. 우리는 멀티클래스 SVM에 대해 직관을 얻도록 몇가지 질문을 했었는데, 소프트맥스 손실와 대조해 보기 위해 같은 질문 몇 개를 생각해 보면 좋을 것 같네요.
소프트맥스 손실의 최대값과 최소값은 뭔가요? 로그와 여러가지가 섞여 있어서 잘 모를 수도 있는데, 답은 최소 손실은 0이고, 최대 손실은 무한대죠. 우리가 원하는 확률 분포는 맞는 클래스에는 1, 틀린 클래스에는 0이죠. 맞는 경우면, 로그 안의 것이 1이되고, 맞는 클래스의 확률의 로그라서, 결국 로그 1은 0이 됩니다. -가 붙어도 0이죠. 즉, 우리가 완전히 다 맞으면, 우리 손실은 0이죠. 그러나 모든 것을 다 맞기 위해서는 우리의 점수가 어때야 할까요? 점수는 꽤 극단적으로 무한대를 향해서 가야 합니다. 우리가 이 지수화와 이 정규화를 가지고 있기 때문에, 우리가 사실 0과 1 확률 분포를 가질 수 있는 유일한 방법은 무한대의 점수를 맞는 클래스에 주고, 마이너스의 무한대 점수를 틀린 클래스에 주는 것이죠. 컴퓨터는 무한대를 잘 쓰지 못하는데, 그래서 실제로는 아마 0 손실은 얻지 못할 겁니다. 유한한 정밀도로는 말이죠. 그러나 0은 이론적인 최소 손실이라고 해석해 볼 수 있고 최대 손실은 범위가 없습니다. 만약 우리가 맞는 클래스에 대해 0 확률 질량을 가진다면, 그럼 여러분은 0의 마이너스 로그를 갖게 되죠. 로그 0은 음의 무한대고, 음수 로그 0은 양의 무한대죠. 매우 안좋네요. 그러나 여러분은 이걸 보진 못할 겁니다. 왜냐면 이 확률이 0이되는 유일한 방법은 맞는 클래스 점수 제곱이 0이면, 이건 단지 맞는 클래스 점수가 음의 무한대일 때 가능합니다. 그래서 다시 말하면 여러분은 유한한 정확도로는 이 최소값과 최대값을 가질 수 없을 겁니다.
기억할건 우리는 멀티클래스 SVM 문맥에서 디버깅 온전성 검사 질문인데, 소프트맥스에 대해서도 같은 질문을 할 수 있습니다.
만약 모든 s가 작다면, 즉 약 0이면, 손실은 뭘까요? - log (1/C)죠. 로그는 플리핑 (flipping)할 수 있으니 그냥 log C죠. 이건 좋은 디버깅인데, 만약 소프트맥스 손실로 모델을 훈련한다면, 먼저 첫번째 반복을 확인해 봐야 합니다. 이게 log C가 아니면, 뭔가 잘못된 겁니다.
우리는 이 두 손실 함수를 비교하고 대조해 볼 수 있는데요. 선형분류에 관한 이 셋업은 똑같아 보입니다. 우리는 입력에 대해 곱해지는 W 매트릭스를 갖고 있고, 그걸로 이 점수 벡터를 만들 수 있죠. 이 두 손실 함수간의 차이는 우리가 나중에 정량적으로 나쁨을 측정하기 위해서 어떻게 그 점수들을 해석하는 지로 결정됩니다. SVM에 대해서는, 우리는 맞는 클래스 점수와 틀린 클래스 점수 간의 마진을 볼 거고, 소프트맥스 혹은 교차 엔트로피 손실에 대해서는, 확률 분포를 계산할 겁니다. 그리곤 맞는 클래스의 음의 로그 확률을 보죠.
이 두 손실 함수를 대조할 때, 재미있는 질문이 있는데요. 이 예제 포인트를 갖는다고 해 보죠. 위 그림 아래쪽은 무시하고, 우리가 이것에 대한 3개의 점수를 갖는다면, 우리가 이전에 본 예제로 다시 돌아가면 , 멀티클래스 SVM 손실에서, 차가 있었고, 차 점수는 다른 틀린 클래스 보다 훨씬 점수가 높았죠. 그 차 이미지의 점수를 약간 바꿔보는 건 멀티클래스 SVM 손실을 전혀 바꾸지 않았죠. 왜냐면 SVM 손실이 신경 쓰는 건 틀린클래스 점수와 비교해서 마진보다 큰 맞는 점수를 얻는 것이죠. 그러나 소프트맥스 손실은 이점에서 꽤 다릅니다. 소프트맥스 손실은 여러분이 맞는 클래스에 대해서 매우 높은 점수를 주더라도 항상 확률 질량을 1로 몰아가길 원하죠. 틀린 클래스에 대해 아주 낮은 확률을 주고요. 소프트맥스는 여러분이 점점 더 많은 확률 질량을 맞는 클래스 위에 쌓길 원할 겁니다. 맞는 클래스의 그 점수를 무한대를 향해 위로 밀어붙이죠. 그리고 틀린 클래스의 점수는 아래로 음의 무한대로요. 이게 실제에서 두 손실함수의 재밌는 차이죠. SVM은 이 데이타 포인트를 바 (bar) 위로 올려 맞는 클래스로 분류되도록 한 후엔, 포기하죠. 더이상 데이타 포인트를 상관하지 않습니다. 반면 소프트맥스는 계속해서 모든 데이타 포인트를 개선해서 더 나아지도록 노력합니다. 이게 이 두 함수의 재밌는 차이죠. 실제로는, 뭘 선택하든지 별로 큰 차이가 없고, 적어도 많은 딥러닝 어플리케이션에서는 둘은 꽤 비슷하게 동작합니다. 하지만 이 둘의 차이를 알고 있는건 유용하죠.
우리가 배운 데까지 정리해 보면, 우리는 x들과 y들의 데이타 셋이 있고, 우리는 우리의 선형 분류기를 사용해서 우리의 입력인 x로 부터 점수 s를 계산하는 어떤 점수 함수를 얻고나서 우리는 손실 함수를 사용할 건데요. 아마도 우리의 예측이 그라운드 참 타겟 (ground true target) y와 비교해서 정량적으로 얼마나 나쁜지 계산하기 위해 소프트맥스, SVM이나 혹은 다른 어떤 다른 손실 함수를 사용할 겁니다. 그리고 우리는 종종 이 손실 함수를 정규화 항으로 개선하죠. 그 항은 학습 데이타에 핏 (fit)하는 것과 간단한 모델을 선호하는 것 사이에서 트래이드오프 (trade-off)하려고 합니다. 이건 우리가 흔히 지도 학습이라고 부르는 것의 꽤 일반적인 개요죠. 우리가 앞으로 딥 러닝에서 볼 것은, 일반적으로 구조상 매우 복잡한 함수 f를 지정하고 싶을 거고, 여러분의 알고리즘이 주어진 어떤 파라미터에 대해서 얼마나 잘 하는지 결정하는 어떤 손실함수를 지정하죠. 모델 복잡도에 패널티를 주는 어떤 정규화 항도 있구요. 이 모든걸 조합해서 최종 손실함수를 최소화하는 W를 찾으려고 하죠.
그 다음 질문은, 우리는 실제로 어떻게 그걸 하죠? 어떻게 이 손실을 최소화하는 W를 찾죠? 그것이 우리를 최적화라는 주제로 이끕니다.
'AI' 카테고리의 다른 글
[cs231n] 3강 손실 함수와 최적화 (4/4, 경사하강 / Gradient Descent) (0) 2021.01.12 [cs231n] 3강 손실 함수와 최적화 (3/4, 최적화 / optimization) (0) 2021.01.11 [cs231n] 3강 손실 함수와 최적화 (1/4, 멀티클래스 (multiclass) SVM) (0) 2021.01.10 [cs231n] 2강 이미지 분류 (4/4, 선형 분류기) (0) 2021.01.06 [cs231n] 2강 이미지 분류 (3/4, K-최근접 이웃/ K-Nearest Neighbors) (0) 2021.01.05