我是 PyTorch 的新手,正在尝试用 PyTorch 创建一个自动编码器,这是我的代码:
编码器:
# B = Batch size
# encoder (B, 3, 224, 224) => (B, 8)
# Encoder exactly as posted in the question. The intended mapping is
# (B, 3, 224, 224) => (B, 8), but the pooling layers below prevent it from running.
class Encoder(nn.Module):
def __init__(self):
super().__init__()
# Convolutional feature extractor, built as a single nn.Sequential.
self.encoder_cnn = nn.Sequential(
# input shape: (B, 3, 224, 224) =>
nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=0),
nn.ReLU(True),
nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=0),
nn.ReLU(True),
nn.BatchNorm2d(16),
# NOTE(review): with return_indices=True this layer returns a tuple
# (out, indices). Inside nn.Sequential that tuple becomes the input of the
# following Conv2d, which expects a tensor. The line below also lacks a
# trailing comma before the next layer.
nn.MaxPool2d(2,return_indices=True)
# shape: (B, 16, 55, 55) =>
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(True),
# shape: (B, 32, 28, 28) =>
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(True),
nn.BatchNorm2d(64),
# NOTE(review): same tuple-return issue as the first pooling layer; here the
# result is what forward() hands to self.flat.
nn.MaxPool2d(2,return_indices=True),
# output shape: (B, 64, 7, 7)
)
self.flat = nn.Flatten(start_dim=1) # shape: (B, 64*7*7)
# Fully connected head that reduces the flattened features to the 8-value code.
self.encoder_fc = nn.Sequential(
# input shape: (B, 64*7*7)
nn.Linear(64*7*7, 1024),
nn.ReLU(True),
# shape: (B, 1024)
nn.Linear(1024, 8),
nn.Sigmoid()
# output shape: (B, 8)
)
# Runs the CNN, flattens, then applies the fully connected head.
def forward(self, x):
# NOTE(review): this call is where the question's error surfaces, because of
# the return_indices=True pooling layers inside encoder_cnn.
x = self.encoder_cnn(x)
x = self.flat(x)
x = self.encoder_fc(x)
return x
解码器:
# B = Batch size
# decoder (B, 8) => (B, 3, 224, 224)
# Decoder exactly as posted in the question. The intended mapping is
# (B, 8) => (B, 3, 224, 224).
class Decoder(nn.Module):
def __init__(self):
super().__init__()
# Fully connected layers expanding the 8-value code to 64*7*7 features.
self.decoder_fc = nn.Sequential(
nn.Linear(8, 1024),
nn.ReLU(True),
nn.Linear(1024, 64*7*7),
nn.ReLU(True)
)
# Reshape (B, 64*7*7) back to (B, 64, 7, 7) feature maps.
self.unflat = nn.Unflatten(dim=1, unflattened_size=(64, 7, 7))
# NOTE(review): nn.MaxUnpool2d.forward needs the pooling indices as a second
# argument, but nn.Sequential calls each layer with a single input. This is the
# source of the "missing 1 required positional argument" error in the question.
self.decoder_cnn = nn.Sequential(
nn.MaxUnpool2d(2),
nn.BatchNorm2d(64),
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(True),
nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1),
nn.MaxUnpool2d(2),
nn.BatchNorm2d(16),
nn.ReLU(True),
nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=0),
nn.ReLU(True),
nn.ConvTranspose2d(8, 3, kernel_size=3, stride=1, padding=0)
)
# Applies the fully connected layers, reshapes, then runs the transposed-conv stack.
def forward(self, x):
x = self.decoder_fc(x)
x = self.unflat(x)
x = self.decoder_cnn(x)
return x
当我测试编码器时,我得到了这个错误
# Reproduction from the question: build both models and push one sample through the encoder.
encoder = Encoder().to(device)
decoder = Decoder().to(device)
# Add a leading batch dimension to the first dataset item.
# NOTE(review): assumes train_data[0] is a single image tensor — confirm. The models are
# moved to `device` but test_img is not; verify they end up on the same device.
test_img = torch.unsqueeze(train_data[0], dim=0)
print(encoder(test_img))
谢谢你的帮助:p
PS:我尝试去掉 nn.MaxPool2d(2, return_indices=True) 中的 return_indices=True,这样编码器可以成功运行;但是当我运行下面的代码时,又出现了另一个错误:
# Second reproduction from the question: encode one sample, then decode the resulting code.
encoder = Encoder().to(device)
decoder = Decoder().to(device)
# Add a leading batch dimension to the first dataset item.
# NOTE(review): assumes train_data[0] is a single image tensor — confirm. test_img is not
# moved to `device` here either.
test_img = torch.unsqueeze(train_data[0], dim=0)
codes = encoder(test_img)
# The decoder is called with the code only; no pooling indices are passed.
print(decoder(codes).shape)
错误:TypeError: forward() missing 1 required positional argument: 'indices'
问题
当 return_indices=True 时,nn.MaxPool2d.forward 返回一个元组 (out, indices),稍后 nn.MaxUnpool2d 需要用到其中的 indices。但是,您把 Encoder 中的第一个 nn.MaxPool2d 放在了 nn.Sequential 里,并且位于一个 nn.Conv2d 之前。当 return_indices=True 时这就会出问题:返回的元组被直接作为输入传给 nn.Conv2d,而 nn.Conv2d.forward 期望它的第一个参数是一个张量。这就是为什么您会得到 TypeError: conv2d() received an invalid combination of arguments。所以您需要把 indices 保存下来,并确保只把 out 传给下一层。
解决方案
修复方法是以 nn.MaxPool2d 为分隔点,把 nn.Sequential 模块拆分开。您还需要对 Decoder 做同样的处理。代码如下:
# B = Batch size
# encoder (B, 3, 224, 224) => (B, 8)
class Encoder(nn.Module):
    """Convolutional encoder mapping (B, 3, 224, 224) images to (B, latent_dim) codes.

    The two max-pooling layers are kept outside the ``nn.Sequential`` stacks
    because ``return_indices=True`` makes them return ``(out, indices)``;
    only ``out`` may be fed to the next layer, while the indices are returned
    to the caller so a decoder can undo the pooling with ``nn.MaxUnpool2d``.

    Args:
        latent_dim: Size of the code vector produced per sample (default 8).
    """

    def __init__(self, latent_dim: int = 8):
        super().__init__()
        self.encoder_cnn1 = nn.Sequential(
            # input shape: (B, 3, 224, 224) => (B, 8, 222, 222)
            nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=0),
            nn.ReLU(True),
            # => (B, 16, 110, 110)
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=0),
            nn.ReLU(True),
            nn.BatchNorm2d(16),
        )
        # (B, 16, 110, 110) => (B, 16, 55, 55), plus indices of the same shape
        self.max_pool1 = nn.MaxPool2d(2, return_indices=True)
        self.encoder_cnn2 = nn.Sequential(
            # shape: (B, 16, 55, 55) => (B, 32, 28, 28)
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            # shape: (B, 32, 28, 28) => (B, 64, 14, 14)
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
        )
        # (B, 64, 14, 14) => (B, 64, 7, 7), plus indices of the same shape
        self.max_pool2 = nn.MaxPool2d(2, return_indices=True)
        self.flat = nn.Flatten(start_dim=1)  # shape: (B, 64*7*7)
        self.encoder_fc = nn.Sequential(
            # input shape: (B, 64*7*7)
            nn.Linear(64 * 7 * 7, 1024),
            nn.ReLU(True),
            # shape: (B, 1024)
            nn.Linear(1024, latent_dim),
            nn.Sigmoid(),
            # output shape: (B, latent_dim), values in (0, 1)
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape (B, 3, 224, 224).

        Returns:
            A tuple ``(codes, indices1, indices2)`` where ``codes`` has shape
            (B, latent_dim), ``indices1`` has shape (B, 16, 55, 55) and
            ``indices2`` has shape (B, 64, 7, 7). The index tensors come from
            the first and second pooling layers respectively.
        """
        x = self.encoder_cnn1(x)
        x, indices1 = self.max_pool1(x)
        x = self.encoder_cnn2(x)
        x, indices2 = self.max_pool2(x)
        x = self.flat(x)
        x = self.encoder_fc(x)
        return x, indices1, indices2  # also return the indices
# B = Batch size
# decoder (B, 8) => (B, 3, 224, 224)
class Decoder(nn.Module):
    """Convolutional decoder mapping (B, latent_dim) codes to (B, 3, 224, 224) images.

    The two ``nn.MaxUnpool2d`` layers are kept outside the ``nn.Sequential``
    stacks because their ``forward`` needs the pooling indices as a second
    argument, which ``nn.Sequential`` cannot supply.

    A stride-2 ``nn.ConvTranspose2d`` with ``kernel_size=3`` only recovers the
    size a matching stride-2 ``nn.Conv2d`` consumed when ``output_padding`` is
    chosen accordingly. Without it the spatial sizes come out as 27, 53 and
    221 instead of 28, 55 and 222, so the second unpooling step rejects
    ``indices1`` for having a different shape than its input.

    Args:
        latent_dim: Size of the code vector expected per sample (default 8).
    """

    def __init__(self, latent_dim: int = 8):
        super().__init__()
        self.decoder_fc = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 64 * 7 * 7),
            nn.ReLU(True),
        )
        # (B, 64*7*7) => (B, 64, 7, 7)
        self.unflat = nn.Unflatten(dim=1, unflattened_size=(64, 7, 7))
        # (B, 64, 7, 7) => (B, 64, 14, 14)
        self.max_unpool1 = nn.MaxUnpool2d(2)
        self.decoder_cnn1 = nn.Sequential(
            nn.BatchNorm2d(64),
            # (B, 64, 14, 14) => (B, 32, 28, 28); output_padding=1 turns 27 into 28
            nn.ConvTranspose2d(
                64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.ReLU(True),
            # (B, 32, 28, 28) => (B, 16, 55, 55), matching the shape of indices1
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1),
        )
        # (B, 16, 55, 55) => (B, 16, 110, 110)
        self.max_unpool2 = nn.MaxUnpool2d(2)
        self.decoder_cnn2 = nn.Sequential(
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            # (B, 16, 110, 110) => (B, 8, 222, 222); output_padding=1 turns 221 into 222
            nn.ConvTranspose2d(
                16, 8, kernel_size=3, stride=2, padding=0, output_padding=1
            ),
            nn.ReLU(True),
            # (B, 8, 222, 222) => (B, 3, 224, 224)
            nn.ConvTranspose2d(8, 3, kernel_size=3, stride=1, padding=0),
        )

    def forward(self, x, indices1, indices2):  # accept the indices
        """Decode a batch of codes.

        Args:
            x: Tensor of shape (B, latent_dim).
            indices1: Pooling indices of shape (B, 16, 55, 55) from the
                encoder's first max-pool layer.
            indices2: Pooling indices of shape (B, 64, 7, 7) from the
                encoder's second max-pool layer.

        Returns:
            Tensor of shape (B, 3, 224, 224).
        """
        x = self.decoder_fc(x)
        x = self.unflat(x)
        # Pooling is undone in reverse order: the last pool's indices come first.
        x = self.max_unpool1(x, indices2)
        x = self.decoder_cnn1(x)
        x = self.max_unpool2(x, indices1)
        x = self.decoder_cnn2(x)
        return x
然后你可以像这样运行编码器和解码器
# Round trip: encode one random image, then decode it using the pooling indices.
encoder, decoder = Encoder(), Decoder()
# A single random 3x224x224 image with a leading batch dimension added.
test_img = torch.rand(3, 224, 224).unsqueeze(0)
codes, indices1, indices2 = encoder(test_img)
reconstruction = decoder(codes, indices1, indices2)
print(reconstruction.shape)
警告
上面的代码运行时不会再遇到与参数错误相关的 TypeError(例如问题中的那些错误)。然而,它会引发另一个关于 indices1 形状的错误。我怀疑这与解码器中最大反池化层的内核大小、步幅或填充有关,但坦率地说,我对计算机视觉不够熟悉,无法调试此错误。也就是说,上面的代码确实解决了问题中的错误,所以我认为这篇文章仍然可以算作一个答案。