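I designed a class that is a network block. Its forward takes three inputs (x, logdet, reverse) and returns two outputs. When I instantiate the class and call it directly, everything works fine, for example: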
```python
x = torch.Tensor(np.random.rand(2, 48, 8, 8))
net = Block(inp=48, oup=48, mid_channels=48, ksize=3, stride=1, group=3)
a, _ = net(x, reverse=False)
```
However, when I try to use it through `nn.Sequential` (because I need several blocks one after another), the following problem occurs:
```python
x = torch.Tensor(np.random.rand(2, 48, 8, 8))
conv1_network = nn.Sequential(
    Block(inp=48, oup=48, mid_channels=48, ksize=3, stride=1, group=3)
)
conv1_network(x, reverse=False)
```
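The error I get is: `TypeError: forward() got an unexpected keyword argument 'reverse'`

This is strange, because, as the first snippet shows, the forward of my Block does accept `reverse` as an input. I'm looking for a way to connect several blocks to each other. For example, here is one block: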
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        # Initialize with a random orthogonal matrix
        w_init = np.random.randn(num_channels, num_channels)
        w_init = np.linalg.qr(w_init)[0].astype(np.float32)
        self.weight = nn.Parameter(torch.from_numpy(w_init))

    def forward(self, x, logdet, reverse=False):
        # log|det W| counted once per spatial position (H * W)
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
        if reverse:
            # Inverse pass: apply W^-1 and subtract the log-determinant
            weight = torch.inverse(self.weight.double()).float()
            logdet = logdet - ldj
        else:
            weight = self.weight
            logdet = logdet + ldj
        # Use W as a 1x1 convolution kernel of shape (C, C, 1, 1)
        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        z = F.conv2d(x, weight)
        return z, logdet
```
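My goal is to connect several blocks to each other in a `for` loop, as a Sequential (in my work I can't reuse the same block, so I need different convolutions to build a deep network):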
```python
self.features = []
for i in range(10):
    self.features.append(Block(num_channels=48))
```
Then I want to use them like this:

```python
self.features(x, logdet=0, reverse=False)
```
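You point out that your `Block` `nn.Module` has a `reverse` option. However, `nn.Sequential` does not: its `forward` takes a single input and no keyword arguments. So `conv1_network(x, reverse=False)` is invalid, because `conv1_network` is not a `Block`.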
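A minimal, self-contained reproduction of the error, using a hypothetical `NeedsReverse` module as a stand-in for `Block`: calling the module directly accepts `reverse`, but the same module wrapped in a plain `nn.Sequential` does not, because the keyword is handed to `Sequential`'s own `forward`. The exact wording of the message varies with the Python version, so only the presence of `'reverse'` is worth relying on.

```python
import torch
import torch.nn as nn


class NeedsReverse(nn.Module):
    """Hypothetical stand-in for Block: its forward accepts a reverse keyword."""

    def forward(self, x, reverse=False):
        return -x if reverse else x


layer = NeedsReverse()
print(layer(torch.ones(1), reverse=True))  # fine: the kwarg reaches NeedsReverse.forward

seq = nn.Sequential(NeedsReverse())
caught = None
try:
    # The kwarg goes to nn.Sequential.forward(input), which has no 'reverse' parameter
    seq(torch.ones(1), reverse=True)
except TypeError as err:
    caught = err
print("TypeError:", caught)
```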
By default, you cannot pass `kwargs` to the layers inside an `nn.Sequential`. You can, however, inherit from `nn.Sequential` and do it yourself, along these lines:
```python
class BlockSequence(nn.Sequential):
    def forward(self, input, **kwargs):
        for module in self:
            # Only Block modules receive the keyword arguments
            options = kwargs if isinstance(module, Block) else {}
            input = module(input, **options)
        return input
```
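To see the keyword filtering in isolation, here is a self-contained sketch of the same `BlockSequence` logic, with the `isinstance` check pointed at a hypothetical `ToyBlock` that returns a single tensor (the case this generic version handles directly). The `nn.ReLU` in the middle would raise a `TypeError` if it were handed `reverse`, so the sequence only runs because the kwargs are withheld from it.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical Block stand-in: scales by 2, or divides by 2 when reverse=True."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(2.0))

    def forward(self, x, reverse=False):
        return x / self.scale if reverse else x * self.scale


class BlockSequence(nn.Sequential):
    def forward(self, input, **kwargs):
        for module in self:
            # Only ToyBlock modules receive the keyword arguments
            options = kwargs if isinstance(module, ToyBlock) else {}
            input = module(input, **options)
        return input


seq = BlockSequence(ToyBlock(), nn.ReLU(), ToyBlock())
x = torch.tensor([-1.0, 3.0])
print(seq(x, reverse=False))  # ToyBlock -> ReLU -> ToyBlock
print(seq(x, reverse=True))   # same modules, each ToyBlock dividing instead
```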
This way, you can create a sequence that contains `Block`s (and, optionally, non-`Block` modules):
```python
>>> blocks = []
>>> for i in range(10):
...     blocks.append(Block(num_channels=48))
...
>>> blocks = BlockSequence(*blocks)
```
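Unpacking the list into `BlockSequence(*blocks)` is also what registers the blocks' weights. If the blocks were instead kept in a plain Python list attribute (as in the question's `self.features` loop), PyTorch would not see them as submodules, so `.parameters()` (and therefore the optimizer), `.to(device)` and `state_dict()` would all miss them. A small check, using `nn.Conv2d` layers as hypothetical stand-ins for `Block`; `nn.ModuleList` is the usual container when you want to keep a list and loop over it yourself, since it has no `forward` of its own.

```python
import torch.nn as nn


class PlainListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the layers are NOT registered as submodules
        self.features = [nn.Conv2d(48, 48, 1) for _ in range(10)]


class ModuleListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every layer, so .parameters(), .to(), etc. see them
        self.features = nn.ModuleList([nn.Conv2d(48, 48, 1) for _ in range(10)])


print(len(list(PlainListNet().parameters())))   # no parameters found
print(len(list(ModuleListNet().parameters())))  # weight + bias for each of the 10 convs
```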
You can then call `blocks` with the `reverse` keyword argument, which is relayed to every underlying `Block` submodule when called:

```python
>>> blocks(x, logdet=0, reverse=False)
```
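One caveat: the `Block` in the question returns a tuple `(z, logdet)`, so with the generic `BlockSequence` the second `Block` would receive the previous block's tuple as `x` and fail, and every `Block` would be handed the same `logdet=0` instead of the running total. Below is a minimal sketch of a variant that threads both values through, assuming every `Block` has the `forward(x, logdet, reverse)` signature shown in the question. `FlowSequence` is a hypothetical name, and running the blocks in the opposite order when `reverse=True` is an addition not in the answer above (inverting a chain of invertible layers means undoing the last one first).

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    """The question's invertible 1x1 convolution block."""

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        w_init = np.linalg.qr(np.random.randn(num_channels, num_channels))[0]
        self.weight = nn.Parameter(torch.from_numpy(w_init.astype(np.float32)))

    def forward(self, x, logdet, reverse=False):
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
        if reverse:
            weight = torch.inverse(self.weight.double()).float()
            logdet = logdet - ldj
        else:
            weight = self.weight
            logdet = logdet + ldj
        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        return F.conv2d(x, weight), logdet


class FlowSequence(nn.Sequential):
    """Hypothetical BlockSequence variant that carries (x, logdet) through each Block."""

    def forward(self, x, logdet=0.0, reverse=False):
        modules = list(self)
        if reverse:
            modules = modules[::-1]  # undo the last block first
        for module in modules:
            if isinstance(module, Block):
                x, logdet = module(x, logdet, reverse=reverse)
            else:
                x = module(x)  # non-Block modules get only the tensor (and are not inverted)
        return x, logdet


torch.manual_seed(0)
np.random.seed(0)
flow = FlowSequence(*[Block(num_channels=48) for _ in range(10)])
x = torch.randn(2, 48, 8, 8)
with torch.no_grad():
    z, logdet = flow(x, logdet=0.0, reverse=False)
    x_rec, logdet_rec = flow(z, logdet=logdet, reverse=True)
print(torch.allclose(x, x_rec, atol=1e-4))  # the reverse pass should recover x
```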