Pytorch TypeError语言 - eq() 收到了无效的参数组合



我正在研究BERT的文本分类问题。在本地机器上训练时,一切正常,但是切换到服务器时,出现以下错误:

<ipython-input-28-508d35ac5f5f> in flat_accuracy(preds, labels)
5     pred_flat = np.argmax(preds, axis=1).flatten()
6     labels_flat = labels.flatten()
----> 7     return np.sum(pred_flat == labels_flat) / len(labels_flat)
8 
9 # Function to calculate the f1_score of our predictions vs labels
TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
* (Tensor other)
didn't match because some of the arguments have invalid types: (numpy.ndarray)
* (Number other)
didn't match because some of the arguments have invalid types: (numpy.ndarray)

法典:

def flat_accuracy(preds, labels):
pred_flat = np.argmax(preds, axis=1).flatten()
labels_flat = labels.flatten()
return np.sum(pred_flat == labels_flat) / len(labels_flat)

本地机器上的火炬版本:1.4.0

服务器上的火炬版本:1.3.1

任何帮助将不胜感激!

可能是服务器上火炬版本的eq实现不再允许您在torch.Tensornp.ndarray之间进行元素比较。你应该强迫pred_flat成为torch.Tensor,或者强迫labels_flat成为 numpy 数组。由于您在 return 语句中使用np.sum并且您只是返回一个标量值,所以我只是将所有内容移动到 numpy,所以

labels_flat = labels.numpy()

但是如果你在GPU上,你可能需要调用labels.cpu().numpy(),如果你正在跟踪标签上的梯度,你可能需要labels.detach().cpu().numpy()

最新更新