为什么这个输入不通过这个简单的 PyTorch 模型?



我有一个输入张量,它的形状是:

torch.Size([256, 3, 28, 28])

(批量大小为256,3通道,28x28图像)

和这样的模型:

class Model(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(3, 28, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(28, 56, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),  # output: 56 x 16 x 16
nn.Conv2d(56, 112, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(112, 112, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),  # output: 112 x 8 x 8
nn.Conv2d(112, 224, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(224, 224, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),  # output: 224 x 4 x 4
nn.Flatten(),
nn.Linear(224 * 4 * 4, 896),
nn.ReLU(),
nn.Linear(896, 512),
nn.ReLU(),
nn.Linear(512, 2))
def forward(self, xb):
return self.network(xb)

当我尝试向前传递数据时,它失败了:

...
return self.network(xb)
File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
input = module(input)
File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

我错过了什么?

谢谢!

nn。MaxPool2d(2,2), # output: 56 × 16 × 16

这是错误的。原始输入的大小为(256,3,28,28)。你使用的卷积层和ReLU层不会改变批处理、高度或宽度维度;他们只是换了"频道"而已;维度。在最大池化层之前,张量大小为(256,56,28,28)。最大池化层的内核大小为2,步幅为2,因此它将高度和宽度都减半。因此,这个最大池化层的输出大小为(256,56,14,14)。

出于同样的原因,下一个最大池化层的输出大小为(256,112,7,7),最后一个最大池化层的输出大小为(256,224,3,3)。

所以你可以通过将输入大小更改为(256,3,32,32)来解决这个问题,如果可以的话,或者将第一个线性层更改为nn.Linear(224 * 3 * 3, 896)

最新更新