본문 바로가기
인공지능/ML, DL

[PyTorch] CrossEntropyLoss가 음의 무한대로 갈 때

by judy@ 2023. 11. 23.

최근 PyTorch로 모델 쌓는 법을 배우면서 from scratch 로 구현하는 시도를 해보고 있다.

그런데, loss가 점점더 -∞로 가고 업데이트되지 않는 문제가 발생했다.

 

아주 사소해보이지만 치명적인 실수를 했다는 것을 알게 되어 얼굴을 붉히며 포스팅을 남겨본다 🥵

 

우선 하려고 했던 건 간단하다. MNIST를 flatten하여 입력으로 넣고,
2개의 hidden layer를 가지고 10개의 로짓을 출력하는 매우 간단한 모델을 만들었다.

모델의 파라미터는 he, zero로 초기화, optimizer로는 Adam을 사용하였다.

 

코드를 보자.

# loss function
loss_fn = nn.CrossEntropyLoss()

for e in range(epochs):
    for b, (data, label) in enumerate(train_dataloader):

        data, label = data.float().to(device), label.float().to(device)
        pred = model.forward(data.view((-1, 28*28)))

        optim.zero_grad()
        loss = loss_fn(label, pred) # 문제의 코드
        loss.backward()
        optim.step()

 

뭐가 문제인지 눈치챘는가? 바로 label과 pred의 순서이다. CrossEntropyLoss의 forward() 메서드의 순서를 보면, input 다음에 target을 받고 있다. 근데 나는? 거꾸로 넣었다.

# Partial source code of torch.nn.CrossEntropyLoss()
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    return F.cross_entropy(input, target, weight=self.weight,
                           ignore_index=self.ignore_index, reduction=self.reduction,
                           label_smoothing=self.label_smoothing)

 

CrossEntropyLoss는 softmax가 취해지지 않은 로짓인 pred와 원핫벡터인 label을 입력으로 받은 뒤, 로짓에 softmax를 취해 원핫벡터와의 CELoss를 구해주는 훌륭한 도우미이다.

 

하지만 이 순서를 반대로 넣어버리면 이미 원핫 벡터인 pred에 의미없이 softmax를 취할 뿐만 아니라,
ylogy^을 y^logy로 계산하고, log0은 -∞와 가까워져 loss가 음의 방향으로 매우 커지게 된다.
log1은 오히려 0이기 때문에 정답에 대해 예측한 로짓은 오히려 loss에 포함되지 않게 된다.

 

하여간 총체적 난국이 되어버리고, 결과는 이렇게..^^

[step 0] loss: -24815378490.34648 acc: 0.09945238095238096
[step 1] loss: -318719438808.0 acc: 0.09945238095238096
[step 2] loss: -1145133952800.0 acc: 0.09945238095238096
[step 3] loss: -2619462651520.0 acc: 0.09945238095238096
[step 4] loss: -4828149267456.0 acc: 0.09945238095238096
[step 5] loss: -7862875674624.0 acc: 0.09945238095238096
[step 6] loss: -11819620434432.0 acc: 0.09945238095238096
[step 7] loss: -16788731928576.0 acc: 0.09945238095238096
[step 8] loss: -22859443149824.0 acc: 0.09945238095238096
[step 9] loss: -30131526211584.0 acc: 0.09945238095238096

 

도대체 뭐가 문제일까 한참을 헤매이다 hoxy... 하는 심정으로 순서를 뙇 바꿨는데 

loss = loss_fn(pred, label) # (label, pred) 에서 변경

 

하하하하하핳하하 이렇게 괜찮은 녀석(=모델)에게 못된 짓을 하고 있었다니... 눈물이 앞을 가렸다..

[step 0] loss: 379.69621264748275 acc: 0.9560714285714286
[step 1] loss: 163.69974727369845 acc: 0.9747380952380953
[step 2] loss: 116.49066282110289 acc: 0.978
[step 3] loss: 87.3123155075009 acc: 0.9848809523809524
[step 4] loss: 72.63109205308137 acc: 0.9911666666666666
[step 5] loss: 56.055692304333206 acc: 0.9843571428571428
[step 6] loss: 44.11998969452543 acc: 0.9924761904761905
[step 7] loss: 37.779681337553484 acc: 0.9954761904761905
[step 8] loss: 31.468174697056384 acc: 0.993452380952381
[step 9] loss: 29.603860789306054 acc: 0.9950952380952381

 

이 글을 읽는 누군가는.. 이런 실수를 하지 않길 바라며.... 글을 마친다..

반응형