pytorch闪电中的"stage is None"是什么时候



一些官方pytorch闪电文档的代码将stage称为Optional[str],例如以下代码

import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms

class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage: Optional[str] = None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
if stage == "predict" or stage is None:
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=32)

阶段什么时候取None的值?我找不到描述这一点的文档。

Trainer永远不会将stage=None发送到设置挂钩,或接受此参数的任何其他挂钩。由于历史原因,该类型被注释为可选,默认值为None。这些值将总是"0"中的一个;"适合"验证"测试"预测";。

有一个RFC将其更改为必需的参数,以避免混淆。该链接提供了更多的背景信息,说明过去为什么会这样。

报价来源:

此方法需要一个stage参数。它用于分离设置CCD_ 5的逻辑。如果使用调用setupstage=None,我们假设所有阶段都已设置。

下面是一个简单的代码示例;如果您使用这个LightningDataModule代码段:

class MNISTDataModule(LightningDataModule):
def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 256 if torch.cuda.is_available() else 64):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def setup(self, stage):
print(stage)

运行时:

MNIST_dm = MNISTDataModule(PATH_DATASETS)
trainer.fit(mnist_model, MNIST_dm)
trainer.validate(mnist_model, MNIST_dm)
trainer.test(mnist_model, MNIST_dm)
MNIST_dm.setup(stage="None")

你会看到这个打印出来的:

TrainerFn.FITTING
TrainerFn.VALIDATING
TrainerFn.TESTING
None

当您将其显式设置为None时,它是None,否则它将采用称为的阶段名称

相关内容

最新更新