PyTorch闪电在整个时代平均指标吗?



我正在查看PyTorch-Lightning官方文档https://pytorch-lightning.readthedocs.io/en/0.9.0/lightning-module.html中提供的示例。

这里的损失和度量是在混凝土批次上计算的。但是,当进行日志记录时,人们对特定批的准确性不感兴趣,这些批可能相当小且不具有代表性,而是对所有历元的平均值感兴趣。我是否理解正确,有一些代码对所有批次执行平均,通过epoch?

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM
class ClassificationTask(pl.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
return pl.TrainResult(loss)
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
acc = FM.accuracy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log_dict({'val_acc': acc, 'val_loss': loss})
return result
def test_step(self, batch, batch_idx):
result = self.validation_step(batch, batch_idx)
result.rename_keys({'val_acc': 'test_acc', 'val_loss': 'test_loss'})
return result
def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=0.02)

如果您想要在epoch内平均度量,您需要告诉您已子类化的LightningModule这样做。有几种不同的方法可以做到这一点,例如:

  1. 调用result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True),如on_epoch=True文档中所示,以便在整个历元中平均训练损失。例如:
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
result = pl.TrainResult(loss)
result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return result
  1. 或者,您可以在LightningModule本身上调用log方法:self.log("train_loss", loss, on_epoch=True, sync_dist=True)(可选择传递sync_dist=True以减少跨加速器)。

您将希望在validation_step中做类似的事情以获得聚合的值集度量,或者在validation_epoch_end方法中自己实现聚合。

最新更新