我想删除自动编码器的解码器部分。
我想将 FC 放在移除的部分。
此外,编码器部件不会使用预先学习的权重进行训练。
self.encoder = nn.Sequential(
nn.Conv2d(1, 16, 3, padding=1),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 8, 3, padding=1),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(8, 8, 3, padding=1),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=4, stride=1),
)
self.decoder = nn.Sequential(
nn.Conv2d(8, 8, 3, padding=1),
nn.ReLU(True),
nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2),
nn.Conv2d(8, 8, 3, padding=1),
nn.ReLU(True),
nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2),
nn.Conv2d(8, 16, 3),
nn.ReLU(True),
nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2),
nn.Conv2d(16, 1, 3, padding=1)
)
def forward(self, x):
if self.training :
x = self.encoder(x)
x = self.decoder(x)
return x
else:
x = classifier(x)
return x
这可能吗?帮帮我。。。
一个简单而干净的解决方案是将一个独立的网络定义为解码器,然后在预训练结束后用这个新网络替换模型的解码器属性。下面简单的例子:
class sillyExample(torch.nn.Module):
def __init__(self):
super(sillyExample, self).__init__()
self.encoder = torch.nn.Linear(5, 5)
self.decoder = torch.nn.Linear(5, 10)
def forward(self, x):
return self.decoder(self.encoder(x))
test = sillyExample()
test(torch.rand(30, 5)).shape
Out: torch.Size([30, 10])
test.decoder = torch.nn.Linear(5, 20) # replace the decoder
test(torch.rand(30, 5)).shape
Out: torch.Size([30, 20])
只需确保使用更新的模型(或可能引用模型参数的任何其他内容)重新初始化优化器即可。