如何使用trainer.test
方法来获得所有批次的总准确度?
我知道我可以实现model.test_step
,但这只是针对单个批次。我需要整个数据集的准确性。我可以使用torchmetrics.Accuracy
来积累准确性。但是,将其结合起来并获得总体准确性的正确方法是什么?既然批量测试分数不是很有用,那么model.test_step
应该返回什么?我可以以某种方式破解它,但我很惊讶,我在互联网上找不到任何例子来证明如何用pytorch闪电原生方式获得准确性。
您可以在这里看到(https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#automatic-logging(,CCD_ 6中的CCD_。正确的方法是:
from torchmetrics import Accuracy
def validation_step(self, batch, batch_idx):
x, y = batch
preds = self.forward(x)
loss = self.criterion(preds, y)
accuracy = Accuracy()
acc = accuracy(preds, y)
self.log('accuracy', acc, on_epoch=True)
return loss
如果需要自定义缩减函数,可以使用reduce_fx
参数进行设置,默认值为torch.mean()
。log()
可以从LightningModule
中的任何方法调用
我正在写笔记本。我对下面的代码做了一些初步的实验。
def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
self.test_acc(logits, y)
self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
调用后打印出格式良好的文本
model = Cifar100Model()
trainer = pl.Trainer(max_epochs=1, accelerator='cpu')
trainer.test(model, test_dataloader)
此打印测试_acc 0.008200000040233135
我尝试验证打印的值是否实际上是测试数据批次的平均值。通过如下修改testrongtep:
def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
self.test_acc(logits, y)
self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
preds = logits.argmax(dim=-1)
acc = (y == preds).float().mean()
print(acc)
然后再次运行trainer.test((。这一次打印出以下值:
张量(0.0049(
张量器(0.0078(
张量器这与testrongtep((打印的值非常接近。
这个解决方案背后的逻辑是我在中指定的
self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
on_ech=True,并且我使用了TorchMetric类,平均值由PL自动使用metric.compute((函数计算。
我很快就会把我的笔记本发出来。你也可以在那里查看。