Pytorch lightning metrics: ValueError: preds和target必须具有相同数量的



在谷歌上搜索这个问题没有任何帮助,所以我决定把这个问题作为一个可搜索的问题发布出来,以帮助未来的我和其他人。


def __init__():
...
self.val_acc = pl.metrics.Accuracy()
def validation_step(self, batch, batch_index):
...
self.val_acc.update(log_probs, label_batch)

ValueError: preds and target must have same number of dimensions, or one additional dimension for preds

对于log_probs.shape == (16, 4)label_batch.shape == (16, 4)

有什么问题吗?

pl.metrics.Accuracy()期望一批dtype=torch.long标签,而不是一个热编码标签。

因此,应该输入

self.val_acc.update(log_probs, torch.argmax(label_batch.squeeze(), dim=1))


这与torch.nn.CrossEntropyLoss

相同

相关内容

  • 没有找到相关文章

最新更新