在谷歌上搜索这个问题没有任何帮助,所以我决定把这个问题作为一个可搜索的问题发布出来,以帮助未来的我和其他人。
def __init__():
...
self.val_acc = pl.metrics.Accuracy()
def validation_step(self, batch, batch_index):
...
self.val_acc.update(log_probs, label_batch)
为
ValueError: preds and target must have same number of dimensions, or one additional dimension for preds
对于log_probs.shape == (16, 4)
和label_batch.shape == (16, 4)
有什么问题吗?
pl.metrics.Accuracy()
期望一批dtype=torch.long
标签,而不是一个热编码标签。
因此,应该输入
self.val_acc.update(log_probs, torch.argmax(label_batch.squeeze(), dim=1))
这与torch.nn.CrossEntropyLoss