微调PyTorch:没有深度拷贝



关于在PyTorch中微调cnn,根据保存和加载模型:

如果您只计划保留最佳性能的模型(根据获得的验证损失),那么您必须序列化best_model_state或使用best_model_state = deepcopy(model.state_dict()),否则您的最佳best_model_state将在随后的训练迭代中不断更新。因此,最终的模型状态将是过拟合模型的状态。

然而,我做了这样的事情:

def train_model(model, ...):
...
if validation_loss improves:
delete previous best model
torch.save(model.state_dict(), best_model_path)
else:
....
...
return model
def test_model(model, best_model_path, ...):
model.load_state_dict(torch.load(best_model_path))
model.eval()
...
...
my_model = train_model(my_model, ...)
test_model(my_model, my_path, ...)

换句话说,训练阶段返回的模型是最终可能出现过拟合的模型(我没有使用deepcopy)。但是由于我在训练期间保存了最好的模型,所以在测试/推理阶段我没有问题,因为我加载了最好的模型,重载了测试期间获得的最终模型。

这个解决方案有问题吗?

谢谢。

您仍然遵循教程的说明。注意教程的这一部分:

必须序列化best_model_state或使用best_model_state = deepcopy(model.state_dict())

您序列化了最佳模型的状态(将其写入磁盘),因此您不需要使用deepcopy

如果您将模型保存在内存中,您将使用deepcopy来确保它在训练期间不会被更改。但是因为你把它保存在磁盘上,所以它不会被修改。

最新更新