如何获得pytorch闪电的总测试精度



如何使用trainer.test方法来获得所有批次的总准确度?

我知道我可以实现model.test_step,但这只是针对单个批次。我需要整个数据集的准确性。我可以使用torchmetrics.Accuracy来积累准确性。但是,将其结合起来并获得总体准确性的正确方法是什么?既然批量测试分数不是很有用,那么model.test_step应该返回什么?我可以以某种方式破解它,但我很惊讶,我在互联网上找不到任何例子来证明如何用pytorch闪电原生方式获得准确性。

您可以在这里看到(https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#automatic-logging(,CCD_ 6中的CCD_。正确的方法是:

from torchmetrics import Accuracy
def validation_step(self, batch, batch_idx): 
x, y = batch 
preds = self.forward(x) 
loss = self.criterion(preds, y) 
accuracy = Accuracy()
acc = accuracy(preds, y)
self.log('accuracy', acc, on_epoch=True)
return loss 

如果需要自定义缩减函数,可以使用reduce_fx参数进行设置,默认值为torch.mean()log()可以从LightningModule中的任何方法调用

我正在写笔记本。我对下面的代码做了一些初步的实验。

def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
self.test_acc(logits, y)
self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)

调用后打印出格式良好的文本

model = Cifar100Model()
trainer = pl.Trainer(max_epochs=1, accelerator='cpu')
trainer.test(model, test_dataloader)

此打印测试_acc 0.008200000040233135

我尝试验证打印的值是否实际上是测试数据批次的平均值。通过如下修改testrongtep:

def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
self.test_acc(logits, y)
self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
preds = logits.argmax(dim=-1)
acc = (y == preds).float().mean()
print(acc)

然后再次运行trainer.test((。这一次打印出以下值:
张量(0.0049(
张量器(0.0078(
张量器这与testrongtep((打印的值非常接近。

这个解决方案背后的逻辑是我在中指定的

self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)

on_ech=True,并且我使用了TorchMetric类,平均值由PL自动使用metric.compute((函数计算。

我很快就会把我的笔记本发出来。你也可以在那里查看。

最新更新