如何在pytorch框架中工作时在python类模块中修复此NotImplementedError ?



大家好,我正在使用pytorch研究CIFAR10数据集。我已经开发了一个模型,它的工作绝对很好,但主要问题发生在运行以下代码:

import time
start_time=time.time()
epochs=5
train_losses=[]
test_losses=[]
train_correct=[]
test_correct=[]
for i in range(epochs):
tsn_corr=0
tst_corr=0

for b, (X_train,y_train) in enumerate(train_loader):
b+=1

y_pred=model(X_train)
loss=criterion(y_pred,y_train)


#Tally the number of correct predictions

predicted= torch.max(y_pred.data, 1)[1]
batch_corr=(predicted==y_train).sum()
tsn_corr += batch_corr

#optimize paramters
optimizer.zero_grad()
loss.backward()
optimizer.step()



#print interim results
if b%600 == 0:
print(f"epochs: {i}, batch: {b}, loss: {loss.item():10.8f}")

loss=loss.detach().numpy()
train_losses.append(loss)
train_correct.append(tsn_corr)

#Running the test_batches

with torch.no_grad():
for b, (X_test,y_test) in enumerate(test_loader):
b+=1

y_val=model(X_test)



#TALLY THE NUMBER OF CORRECT PREDICTIONS

predicted=torch.max(y_val.data, 1)[1]
batch_corr= (predicted==y_test).sum()
tst_corr += batch_corr

loss=criterion(y_val,y_test)    
loss=loss.detach().numpy()
test_losses.append(loss)
test_correct.append(tst_corr)

在运行以下代码时发生以下错误:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-43-48e21e83e9f7> in <module>
15         b+=1
16 
---> 17         y_pred=model(X_train)
18         loss=criterion(y_pred,y_train)
19 
~Anaconda3libsite-packagestorchnnmodulesmodule.py in _call_impl(self, *input, **kwargs)
887             result = self._slow_forward(*input, **kwargs)
888         else:
--> 889             result = self.forward(*input, **kwargs)
890         for hook in itertools.chain(
891                 _global_forward_hooks.values(),
~Anaconda3libsite-packagestorchnnmodulesmodule.py in _forward_unimplemented(self, *input)
199         registered hooks while the latter silently ignores them.
200     """
--> 201     raise NotImplementedError
202 
203 
NotImplementedError: 

谁能告诉我我能做些什么来修复这个代码。除此之外,之前所有的代码都工作得很好,我使用卷积神经网络制作的模型也成功运行,这意味着模型没有问题。我想这个细节会有帮助。可能会注意到,这段代码在MNIST数据集上工作得很好。我不知道CIFAR数据集有什么问题

您的模型类需要实现一个forward方法。参见PyTorch示例中的子类化查看示例。

最新更新