我正在研究BERT的文本分类问题。在本地机器上训练时,一切正常,但是切换到服务器时,出现以下错误:
<ipython-input-28-508d35ac5f5f> in flat_accuracy(preds, labels)
5 pred_flat = np.argmax(preds, axis=1).flatten()
6 labels_flat = labels.flatten()
----> 7 return np.sum(pred_flat == labels_flat) / len(labels_flat)
8
9 # Function to calculate the f1_score of our predictions vs labels
TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
* (Tensor other)
didn't match because some of the arguments have invalid types: (numpy.ndarray)
* (Number other)
didn't match because some of the arguments have invalid types: (numpy.ndarray)
法典:
def flat_accuracy(preds, labels):
pred_flat = np.argmax(preds, axis=1).flatten()
labels_flat = labels.flatten()
return np.sum(pred_flat == labels_flat) / len(labels_flat)
本地机器上的火炬版本:1.4.0
服务器上的火炬版本:1.3.1
任何帮助将不胜感激!
可能是服务器上火炬版本的eq
实现不再允许您在torch.Tensor
和np.ndarray
之间进行元素比较。你应该强迫pred_flat
成为torch.Tensor
,或者强迫labels_flat
成为 numpy 数组。由于您在 return 语句中使用np.sum
并且您只是返回一个标量值,所以我只是将所有内容移动到 numpy,所以
labels_flat = labels.numpy()
但是如果你在GPU上,你可能需要调用labels.cpu().numpy()
,如果你正在跟踪标签上的梯度,你可能需要labels.detach().cpu().numpy()
。