본문 바로가기
CS/삽질로그

파이토치 Input type and weight type should be the same .. 오류

by judy@ 2023. 10. 25.

오류 상황

CIFAR10 데이터 세트를 로드하고, Lenet 아키텍처를 빌드하여 학습하려는데 다음과 같은 오류가 났다.

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

 

해결 방안

데이터와 모델 모두 cuda 디바이스에 올린다. -> 오류 해결

# 미니 배치
for batch, (X, y) in enumerate(dataloader):
        X, y = X.to('cuda'), y.to('cuda')
# 모델 및 손실 함수
lenet = LeNet().to('cuda')
loss_fn = nn.CrossEntropyLoss().to('cuda')
optimizer = torch.optim.Adam(lenet.parameters(), lr=learning_rate)

 

반응형