我正在ResNet上工作,我发现了一个用加号跳过连接的实现
Class Net(nn.Module):
def __init__(self):
super(Net, self).__int_()
self.conv = nn.Conv2d(128,128)
def forward(self, x):
out = self.conv(x) // line 1
x = out + x // skip connection // line 2
现在,我已经调试并打印了第1行之前和之后的值。输出如下:
第1行之后
x=[1128,32,32]
out=[1128,32,32]第2行之后
x=[1128,32]//仍然
引用链接:https://github.com/kuangliu/pytorch-cifar/blob/bf78d3b8b358c4be7a25f9f9438c842d837801fd/models/resnet.py#L62
我的问题是它在哪里增加了价值??我是说之后
x=out+x
操作,在哪里添加了值?
PS:张量格式为[批次,通道,高度,宽度]。
正如@UmangGupta在评论中提到的,你打印的似乎是张量的形状(即3x3
矩阵的"形状"是[3, 3]
(,而不是它们的内容。在您的情况下,您正在处理1x128x32x32
张量(。
希望澄清形状和内容之间差异的示例:
import torch
out = torch.ones((3, 3))
x = torch.eye(3, 3)
res = out + x
print(out.shape)
# torch.Size([3, 3])
print(out)
# tensor([[ 1., 1., 1.],
# [ 1., 1., 1.],
# [ 1., 1., 1.]])
print(x.shape)
# torch.Size([3, 3])
print(x)
# tensor([[ 1., 0., 0.],
# [ 0., 1., 0.],
# [ 0., 0., 1.]])
print(res.shape)
# torch.Size([3, 3])
print(res)
# tensor([[ 2., 1., 1.],
# [ 1., 2., 1.],
# [ 1., 1., 2.]])