Pytorch resNet在哪里增值



我正在ResNet上工作,我发现了一个用加号跳过连接的实现

Class Net(nn.Module):
def __init__(self):
super(Net, self).__int_() 
self.conv = nn.Conv2d(128,128)
def forward(self, x):
out = self.conv(x) // line 1 
x = out + x    // skip connection  // line 2

现在,我已经调试并打印了第1行之前和之后的值。输出如下:

第1行之后
x=[1128,32,32]
out=[1128,32,32]

第2行之后
x=[1128,32]//仍然

引用链接:https://github.com/kuangliu/pytorch-cifar/blob/bf78d3b8b358c4be7a25f9f9438c842d837801fd/models/resnet.py#L62

我的问题是它在哪里增加了价值??我是说之后

x=out+x

操作,在哪里添加了值?

PS:张量格式为[批次,通道,高度,宽度]。

正如@UmangGupta在评论中提到的,你打印的似乎是张量的形状(即3x3矩阵的"形状"是[3, 3](,而不是它们的内容。在您的情况下,您正在处理1x128x32x32张量(。

希望澄清形状和内容之间差异的示例:

import torch
out = torch.ones((3, 3))
x = torch.eye(3, 3)
res = out + x
print(out.shape)
# torch.Size([3, 3])
print(out)
# tensor([[ 1.,  1.,  1.],
#         [ 1.,  1.,  1.],
#         [ 1.,  1.,  1.]])
print(x.shape)
# torch.Size([3, 3])
print(x)
# tensor([[ 1.,  0.,  0.],
#         [ 0.,  1.,  0.],
#         [ 0.,  0.,  1.]])
print(res.shape)
# torch.Size([3, 3])
print(res)
# tensor([[ 2.,  1.,  1.],
#         [ 1.,  2.,  1.],
#         [ 1.,  1.,  2.]])

最新更新