我有一个带有预先训练的resnet字段的类模型类似于:
class A(nn.Module):
def __init__(self, **kwargs):
super(A, self).__init__()
self.resnet = get_resnet()
def forward(self, x):
return self.resnet(x)
...
现在我在做
model = A()
...
model.eval()
可以还是应该重写eval
、train
函数?
简短回答
没事。
答案很长
由于nn.Module.train()
像这样递归运行。
self.training = mode
for module in self.children():
module.train(mode)
return self
而nn.Module.eval()
只是在调用self.train(False)
只要self.resnet
是nn.Module
的子类。您不需要为此烦恼,实际上nn.Module
中除forward
之外的所有方法都会影响所有子模块。
你可以通过测试
model = A()
...
model.eval()
print(model.resnet.training) # should be False
如果你得到False
,那么一切都很好。如果你得到了其他东西,那么get_resnet()
就有问题了。