pytorch文档中的"定义模型类"是什么意思



在pytorch关于保存和加载模型的文档页面上,它说当加载保存的模型时,# Model class must be defined somewherehttps://pytorch.org/tutorials/beginner/saving_loading_models.html#:~:text=%23%20Model%20class%20must%20be%20defined%20somewhere

也许我的问题很傻,但在这种情况下,class指的是什么?提前谢谢。

在页面的早些时候,"模型加载过程"被描述为

Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

该上下文中的类指的是您试图用torch.load加载的模型的类。必须定义类,因为该函数将使用存储在PATH中的模型类名来构造模型对象。因此,如果在执行torch.load之前没有定义具有该名称的类,则构造将失败。这个过程类似于pickle加载.pkl文件的方式(事实上,我认为torch.load默认使用pickle(。

请注意,如果保存和加载模型的状态dict(推荐的方式(,则不需要模型类定义,因为状态dict是以字符串为键、以torch.Tensor为值的Python dict。字典和字符串是内置的,因此它们总是被定义的,并且每当您导入torch以使用torch.load时,torch.Tensor总是被定义。

您需要定义模型类,例如,如本文所述。再次使用链接网站的示例作为随机示例,TheModelClass的类可以定义如下:

class TheModelClass(torch.nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.linear1 = torch.nn.Linear(100, 200)
self.activation = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(200, 10)
self.softmax = torch.nn.Softmax()
def forward(self, x):
x = self.linear1(x)
x = self.activation(x)
x = self.linear2(x)
x = self.softmax(x)
return x

相关内容

最新更新