PyTorch Lightning raises an error when passing a ModelCheckpoint



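I am training a Hugging Face model on a multi-label classification problem, and I am using PyTorch Lightning to run the training.

Here is the code:

Early stopping is triggered when the loss stops improving: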

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

Now we can start the training process:

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min"
)

trainer = pl.Trainer(
    logger=logger,
    callbacks=[early_stopping_callback],
    max_epochs=N_EPOCHS,
    checkpoint_callback=checkpoint_callback,
    gpus=1,
    progress_bar_refresh_rate=30
)
# checkpoint_callback=checkpoint_callback,

As soon as I run this, I get the following error:

~/.local/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py in _configure_checkpoint_callbacks(self, checkpoint_callback)
75             if isinstance(checkpoint_callback, Callback):
76                 error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
---> 77             raise MisconfigurationException(error_msg)
78         if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
79             raise MisconfigurationException(
MisconfigurationException: Invalid type provided for checkpoint_callback: Expected bool but received <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>. Pass callback instances to the `callbacks` argument in the Trainer constructor instead.

How can I fix this?

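You can look up the description of the checkpoint_callback argument on the pl.Trainer documentation page:

checkpoint_callback (bool): If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.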
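To make the bool-only contract concrete, here is a simplified, dependency-free sketch of the kind of type check the traceback points at. The classes below are stand-ins defined only for this illustration, not the real pytorch_lightning classes, and check_checkpoint_flag is a hypothetical helper name, not part of Lightning's API.

```python
class Callback:
    """Stand-in for Lightning's Callback base class (illustration only)."""


class ModelCheckpoint(Callback):
    """Stand-in for Lightning's ModelCheckpoint callback (illustration only)."""


class MisconfigurationException(Exception):
    """Stand-in for Lightning's exception of the same name."""


def check_checkpoint_flag(checkpoint_callback):
    """Hypothetical helper: accept only a plain bool, as the docs describe."""
    if not isinstance(checkpoint_callback, bool):
        error_msg = (
            "Invalid type provided for checkpoint_callback: "
            f"Expected bool but received {type(checkpoint_callback)}."
        )
        # Callback instances get an extra hint about where they belong.
        if isinstance(checkpoint_callback, Callback):
            error_msg += (
                " Pass callback instances to the `callbacks` argument"
                " in the Trainer constructor instead."
            )
        raise MisconfigurationException(error_msg)
    return checkpoint_callback


# A bool is accepted: it only switches checkpointing on or off.
check_checkpoint_flag(True)

# A callback instance is rejected, which is the situation in the question.
try:
    check_checkpoint_flag(ModelCheckpoint())
except MisconfigurationException as err:
    message = str(err)
    print(message)
```

The point of the sketch is only that the argument is a switch, so a callback object passed there can never be accepted.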

You should not pass your custom ModelCheckpoint to this argument. I believe what you want is to pass both the EarlyStopping and the ModelCheckpoint callbacks in the callbacks list:

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min"
)

trainer = pl.Trainer(
    logger=logger,
    # Both callbacks go in this list; checkpoint_callback= is no longer passed.
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30
)
