我正在尝试使用nn.ModuleDict
遵循此文档页:
我有这个PyTorch网络:
class Net(nn.Module):
def __init__(self, kernel_size):
super(Net, self).__init__()
modules = {}
modules["layer1"] = nn.Conv2d(3, 16, kernel_size=kernel_size, stride=1, padding=2)
self.modules = nn.ModuleDict(modules)
def forward(self, x):
x = self.modules["layer1"](x)
当我使用forward方法时,我得到以下错误:
'method'对象不可下标
当我将forward方法更改为:
def forward(self, x):
x = self.modules()["layer1"](x)
我得到以下错误:
TypeError: 'generator'对象不可下标
密钥"modules"
已被nn.Module
使用。该属性用于检索模型的所有模块:参见nn.Module.modules
。您需要使用另一个属性名。
例如:
class Net(nn.Module):
def __init__(self, kernel_size):
super(Net, self).__init__()
modules = {}
modules["layer1"] = nn.Conv2d(3, 16,
kernel_size=kernel_size, stride=1, padding=2)
self.layers = nn.ModuleDict(modules)
def forward(self, x):
x = self.layers["layer1"](x)
return x