Problem passing extra inputs to a torch.nn.Sequential of custom blocks



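I designed a class that acts as a network block. Its forward takes three inputs (x, logdet, reverse) and returns two outputs. When I instantiate the class and call it directly, everything works, for example: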

x = torch.Tensor(np.random.rand(2, 48, 8, 8))
net = Block(inp = 48, oup = 48, mid_channels=48, ksize=3, stride=1, group = 3)
a, _ = net(x, reverse = False)

However, when I try to use it through Sequential (because I need several blocks one after another), the following problem occurs:

x = torch.Tensor(np.random.rand(2, 48, 8, 8))
conv1_network = nn.Sequential(
    Block(inp = 48, oup = 48, mid_channels=48, ksize=3, stride=1, group = 3)
)
conv1_network(x, reverse = False)

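The error is: TypeError: forward() got an unexpected keyword argument 'reverse'. This is unexpected, because, as the first snippet shows, the forward of my Block does accept reverse as an input. I am looking for a way to chain several blocks together. For example, this is one block: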

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    def __init__(self, num_channels):
        # The class is named Block, so super() must not refer to InvConv
        super(Block, self).__init__()
        self.num_channels = num_channels
        # Initialize with a random orthogonal matrix
        w_init = np.random.randn(num_channels, num_channels)
        w_init = np.linalg.qr(w_init)[0].astype(np.float32)
        self.weight = nn.Parameter(torch.from_numpy(w_init))

    def forward(self, x, logdet, reverse=False):
        # Log-determinant of the 1x1 convolution, summed over all spatial positions
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
        if reverse:
            weight = torch.inverse(self.weight.double()).float()
            logdet = logdet - ldj
        else:
            weight = self.weight
            logdet = logdet + ldj
        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        z = F.conv2d(x, weight)
        return z, logdet

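My goal is to chain several blocks together in a Sequential built in a for loop (I cannot reuse the same block, so I need distinct convolutions to build a deep network):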

self.features = []
for i in range(10):
    self.features.append(Block(num_channels = 48))

Then I would like to use them like this:

self.features(x, logdet = 0, reverse = False)

As you point out, your Block nn.Module has a reverse option. nn.Sequential, however, does not, so conv1_network(x, reverse=False) is invalid: conv1_network is an nn.Sequential, not a Block.
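To make the failure easy to reproduce, here is a minimal sketch. TinyBlock is a hypothetical stand-in for the asker's Block (it is not from the original post); it only exists to show that the keyword is accepted by the module itself but rejected by the nn.Sequential wrapper, whose forward takes a single input.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for Block: forward accepts an extra keyword."""

    def forward(self, x, reverse=False):
        return -x if reverse else x


x = torch.ones(2, 3)

# Calling the module directly works: the keyword reaches TinyBlock.forward.
direct = TinyBlock()(x, reverse=True)

# Wrapping it in nn.Sequential fails: the keyword is handed to
# Sequential.forward, which only takes the input tensor.
seq = nn.Sequential(TinyBlock())
try:
    seq(x, reverse=True)
    raised = False
    message = ""
except TypeError as err:
    raised = True
    message = str(err)
```

The TypeError is raised by the container before any submodule is called, which is why the signature of Block.forward makes no difference.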

By default, you cannot pass kwargs to the layers inside an nn.Sequential. You can, however, subclass nn.Sequential and forward them yourself. Something like:

class BlockSequence(nn.Sequential):
    def forward(self, input, **kwargs):
        for module in self:
            options = kwargs if isinstance(module, Block) else {}
            input = module(input, **options)
        return input

This way, you can create a sequence containing Blocks (and, optionally, non-Block modules):

>>> blocks = []
>>> for i in range(10):
...     blocks.append(Block(num_channels=48))
>>> blocks = BlockSequence(*blocks)

You will then be able to call blocks with the reverse keyword argument, which is relayed to every Block submodule on each call:

>>> blocks(x, logdet=0, reverse=False)
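One caveat with the Block shown in the question: its forward returns a (z, logdet) tuple, so a loop that reassigns a single input variable would hand that tuple to the next block as x. The following is a hedged sketch of a variant that unpacks the pair and threads logdet through the chain. The name FlowSequence is hypothetical, and traversing the blocks in reverse order when reverse=True is an assumption (the usual convention for invertible networks, but not stated in the original post).

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    """Invertible 1x1 convolution, following the Block in the question."""

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        # Initialize with a random orthogonal matrix
        w_init = np.linalg.qr(np.random.randn(num_channels, num_channels))[0]
        self.weight = nn.Parameter(torch.from_numpy(w_init.astype(np.float32)))

    def forward(self, x, logdet, reverse=False):
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
        if reverse:
            weight = torch.inverse(self.weight.double()).float()
            logdet = logdet - ldj
        else:
            weight = self.weight
            logdet = logdet + ldj
        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        return F.conv2d(x, weight), logdet


class FlowSequence(nn.Sequential):  # hypothetical name, not from the post
    """Chains modules whose forward is (x, logdet, reverse) -> (x, logdet)."""

    def forward(self, x, logdet=0.0, reverse=False):
        # Assumption: the inverse pass visits the blocks in reverse order.
        modules = list(self)[::-1] if reverse else list(self)
        for module in modules:
            x, logdet = module(x, logdet, reverse=reverse)
        return x, logdet


flow = FlowSequence(*[Block(num_channels=8) for _ in range(3)])
x = torch.randn(2, 8, 4, 4)
z, logdet = flow(x, logdet=0.0, reverse=False)
# Running the chain backwards should approximately recover the input.
x_rec, logdet_rec = flow(z, logdet=logdet, reverse=True)
```

Unlike BlockSequence, this sketch assumes every submodule follows the same (x, logdet, reverse) signature, so it does not accept plain layers such as nn.ReLU without an adapter.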
