如何创建模块列表列表



可以创建一个 PyTorch 模块列表的 python 列表吗? 例如,如果我想在一个层中有几个 Conv1d,然后是另一个具有不同 Conv1d 的层。在每一层中,我需要根据层数对输出进行不同的操作。 构建这个模块列表的"python列表"的正确方法是什么?

这边:

class test(nn.Module):
def __init__(...):
self.modulelists = []
for i in range(4):
self.modulelists.append(nn.ModuleList([nn.Conv1d(10, 10, kernel_size=5) for _ in range(5)]))

或者这样:

class test(nn.Module):
def __init__(...):
self.modulelists = nn.ModuleList()
for i in range(4):
self.modulelists.append(nn.ModuleList([nn.Conv1d(10, 10, kernel_size=5) for _ in range(5)]))

谢谢

您需要正确注册网络的所有子模块,以便 pytorch 可以访问它们的参数、缓冲区等。

如果你将子模块存储在一个简单的pythonic列表中,pytorch将不知道那里有子模块,它们将被忽略。

因此,如果您使用简单的pythonic列表来存储子模块,例如,当您调用model.cuda()时,列表中的子模块的参数将不会传输到GPU,而是保留在CPU上。如果调用model.parameters()将所有可训练参数传递给优化器,pytorch 不会检测到所有子模块参数,因此优化器不会"看到"它们。

最新更新