如何更改 Pytorch CNN 以拍摄彩色图像而不是黑白图像?



我发现的这段代码有一个神经网络,可以拍摄黑白图像。(这是一个暹罗网络,但那部分无关紧要)。当我将其更改为拍摄图像而不将它们转换为黑白时,我收到如下所示的错误。
我尝试将第一个 Conv2d,第六行从 1 更改为 3

class SiameseNetwork(nn.Module):
def __init__(self):
super(SiameseNetwork, self).__init__()
self.cnn1 = nn.Sequential(
nn.ReflectionPad2d(1),
# was nn.Conv2d(1, 4, kernel_size=3),
nn.Conv2d(3, 4, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(4),
nn.ReflectionPad2d(1),
nn.Conv2d(4, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),
nn.ReflectionPad2d(1),
nn.Conv2d(8, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8))
self.fc1 = nn.Sequential(
nn.Linear(8*300*300, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 5))
def forward_once(self, x):
output = self.cnn1(x)
output = output.view(output.size()[0], -1)
output = self.fc1(output)
return output
def forward(self, input1, input2):
output1 = self.forward_once(input1)
output2 = self.forward_once(input2)
return output1, output2

当图像未转换为黑白并保持彩色时,我的错误。

RuntimeError: invalid argument 0: Sizes of tensors must match  
except in dimension 0. Got 3 and 1 in dimension 1 at  
/opt/conda/conda-bld/pytorch-nightly_1542963753679/work/aten/src/TH/generic/THTensorMoreMath.cpp:1319

我将图像的形状检查为数组(就在它们进入模型之前)为黑白与彩色......

B&W

torch.Size([1, 1, 300, 300])

在彩色

torch.Size([1, 3, 300, 300])

这是指向我正在使用的整个原始代码的Jupyter笔记本的链接...https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ipynb

编辑:更新:我似乎通过在代码的SiameseNetworkDataset部分中将图像转换为RBG来解决它

img0 = img0.convert("L")

改为

img0 = img0.convert("RGB")

我之前只是注释掉了这条线,并认为这在 RGB 中留下了它,但这是模型不理解的其他内容。 此外,还需要更改 OP

nn.Conv2d(1, 4, kernel_size=3),

改为

nn.Conv2d(3, 4, kernel_size=3),

如果你想回答一个解释模型正在做什么,清楚地表明我会给你绿色的检查。不太懂。卷积2d

错误似乎在下面的完全连接部分:

self.fc1 = nn.Sequential(
nn.Linear(8*100*100, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 5))

似乎CNN的输出是形状[8,300,300]而不是[8,100,100]

要解决此问题,请将输入图像更改为[n_channel, 100,100]或将 fc 层的输入大小更改为8*300*300

最新更新