model.eval,适用于网络领域为pytorch的类



我有一个带有预先训练的resnet字段的类模型类似于:

class A(nn.Module):
def __init__(self, **kwargs):
super(A, self).__init__()
self.resnet = get_resnet()
def forward(self, x):
return self.resnet(x)
...

现在我在做

model = A()
...
model.eval()

可以还是应该重写evaltrain函数?

简短回答

没事。

答案很长

由于nn.Module.train()像这样递归运行。

self.training = mode
for module in self.children():
module.train(mode)
return self

nn.Module.eval()只是在调用self.train(False)

只要self.resnetnn.Module的子类。您不需要为此烦恼,实际上nn.Module中除forward之外的所有方法都会影响所有子模块。

你可以通过测试

model = A()
...
model.eval()
print(model.resnet.training)  # should be False

如果你得到False,那么一切都很好。如果你得到了其他东西,那么get_resnet()就有问题了。

最新更新