Is it possible to add a trainable filter after an autoencoder?



I'm building a denoiser around an autoencoder. The idea is that, before computing my loss (i.e. after the autoencoder), I apply an empirical Wiener filter to a texture map of the image and add it back to the autoencoder output (adding back the "lost details"). I have already coded this filter in PyTorch.

My first attempt was to simply add the filter at the end of the autoencoder's forward function. I can train this network, and it backpropagates through my filter during training. However, if I print my network, the filter is not listed, and torchsummary does not include it when counting parameters.

This makes me think that I am only training the autoencoder, and that my filter applies the same filtering every time instead of learning.
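A quick way to see why the filter never shows up: PyTorch only tracks attributes that are `nn.Parameter` instances (or submodules that contain them). A plain tensor attribute, or a bare function call in `forward`, is invisible to `print(model)`, to torchsummary, and to the optimizer. A minimal toy example (the names here are made up for illustration):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight_plain = torch.ones(3)                # NOT registered: invisible to optimizers
        self.weight_param = nn.Parameter(torch.ones(3))  # registered: trainable

toy = Toy()
names = [n for n, _ in toy.named_parameters()]
print(names)  # ['weight_param'] — only the registered attribute is listed
```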

Is what I am trying to do even possible?

Here is my autoencoder:

class AutoEncoder(nn.Module):
    """Simple autoencoder implementation."""
    def __init__(self):
        super(AutoEncoder, self).__init__()
        # Encoder
        # conv layers
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.ConvTranspose2d(48, 48, 2, 2, output_padding=1),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(96, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(144, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block6 = nn.Sequential(
            nn.Conv2d(97, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, 3, padding=1),
            nn.LeakyReLU(0.1)
        )
        # self.blockNorm = nn.Sequential(
        #     nn.BatchNorm2d(1),
        #     nn.LeakyReLU(0.1)
        # )

    def forward(self, x):
        # torch.autograd.set_detect_anomaly(True)
        # print("input: ", x.shape)
        pool1 = self.block1(x)
        # print("pool1: ", pool1.shape)
        pool2 = self.block2(pool1)
        # print("pool2: ", pool2.shape)
        pool3 = self.block2(pool2)
        # print("pool3: ", pool3.shape)
        pool4 = self.block2(pool3)
        # print("pool4: ", pool4.shape)
        pool5 = self.block2(pool4)
        # print("pool5: ", pool5.shape)
        upsample5 = self.block3(pool5)
        # print("upsample5: ", upsample5.shape)
        concat5 = torch.cat((upsample5, pool4), 1)
        # print("concat5: ", concat5.shape)
        upsample4 = self.block4(concat5)
        # print("upsample4: ", upsample4.shape)
        concat4 = torch.cat((upsample4, pool3), 1)
        # print("concat4: ", concat4.shape)
        upsample3 = self.block5(concat4)
        # print("upsample3: ", upsample3.shape)
        concat3 = torch.cat((upsample3, pool2), 1)
        # print("concat3: ", concat3.shape)
        upsample2 = self.block5(concat3)
        # print("upsample2: ", upsample2.shape)
        concat2 = torch.cat((upsample2, pool1), 1)
        # print("concat2: ", concat2.shape)
        upsample1 = self.block5(concat2)
        # print("upsample1: ", upsample1.shape)
        concat1 = torch.cat((upsample1, x), 1)
        # print("concat1: ", concat1.shape)
        output = self.block6(concat1)
        t_map = x - output
        for i in range(4):
            tensor = t_map[i, :, :, :]            # Take each item in the batch separately; could handle batching inside the Wiener filter instead
            tensor = torch.squeeze(tensor)        # Squeeze to the Wiener input format
            tensor = wiener_3d(tensor, 0.05, 10)  # Apply Wiener with the specified std and block size
            tensor = torch.unsqueeze(tensor, 0)   # Unsqueeze to put it back into the batch
            t_map[i, :, :, :] = tensor            # Put it back into the batch
        filtered_output = output + t_map
        return filtered_output

The for loop at the end applies the filter to each image in the batch. I know this is not parallelizable, so if anyone has ideas on that, I would appreciate them. I can post the `wiener_3d()` filter function if it helps; I'm just trying to keep the post short.
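One small caveat with the loop as written: it assigns into `t_map` in place, which autograd can reject when that tensor is needed for the backward pass. Below is a sketch of the same loop that rebuilds the batch with `torch.stack` instead. Since `wiener_3d` isn't shown, it is replaced here by an identity stand-in, and the shapes are assumed to be `(N, 1, H, W)`:

```python
import torch

def wiener_3d(img, noise_std, block_size):
    # Identity stand-in for the author's filter, just to make the sketch runnable.
    return img

t_map = torch.randn(4, 1, 8, 8)
output = torch.randn(4, 1, 8, 8)

# Filter each (H, W) image, then rebuild the batch with torch.stack instead of
# writing into t_map in place — this keeps the autograd graph intact.
filtered = [wiener_3d(t_map[i].squeeze(0), 0.05, 10) for i in range(t_map.shape[0])]
t_map_f = torch.stack(filtered).unsqueeze(1)  # back to (N, 1, H, W)
filtered_output = output + t_map_f
print(filtered_output.shape)  # torch.Size([4, 1, 8, 8])
```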

I tried to define a custom layer class containing the filter, but I quickly got lost.

Any help would be greatly appreciated!

If you just want to turn your Wiener filter into a module, the following will do:

class WienerFilter(torch.nn.Module):
    def __init__(self, param_a=0.05, param_b=10):
        super(WienerFilter, self).__init__()
        # Registered as a parameter, so it can still be accessed like any
        # other member via self.param_a, but is now visible to optimizers
        self.register_parameter("param_a", torch.nn.Parameter(torch.tensor(param_a)))
        self.param_b = param_b

    def forward(self, input):
        for i in range(4):  # batch size hard-coded here, as in the question
            tensor = input[i]
            tensor = torch.squeeze(tensor)
            tensor = wiener_3d(tensor, self.param_a, self.param_b)
            tensor = torch.unsqueeze(tensor, 0)
            input[i] = tensor
        return input

You can apply it by adding the line

self.wiener_filter = WienerFilter()

in the AutoEncoder's __init__ function.

In the forward pass, you then replace the for loop with

filtered_output = output + self.wiener_filter(t_map)

Torch knows that the wiener_filter module is a member module, so it will list the module if you print your AutoEncoder's modules.
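To verify the registration works as described, here is a self-contained check. The forward body is a placeholder since `wiener_3d` isn't shown, and the surrounding `Denoiser` is a hypothetical stand-in for the AutoEncoder:

```python
import torch
import torch.nn as nn

class WienerFilter(nn.Module):
    def __init__(self, param_a=0.05, param_b=10):
        super().__init__()
        self.register_parameter("param_a", nn.Parameter(torch.tensor(param_a)))
        self.param_b = param_b

    def forward(self, x):
        return x * self.param_a  # placeholder body for the demo

class Denoiser(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.wiener_filter = WienerFilter()  # registered as a submodule

model = Denoiser()
print(model)  # WienerFilter appears as a submodule
names = [n for n, _ in model.named_parameters()]
print(names)  # includes 'wiener_filter.param_a' — it will be trained
```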

If you want to parallelize your Wiener filter, you need to do that in PyTorch's terms, meaning with its operations on tensors. Those operations are implemented in a parallel fashion.
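As an illustration of what "in PyTorch's terms" can look like, below is a simplified, global (non-blockwise) Wiener-style shrinkage built entirely from batched tensor operations via `torch.fft`. It is not the author's `wiener_3d`, only a sketch of the idea: there is no Python loop over the batch, so every image is processed in parallel:

```python
import torch

def wiener_shrink_batched(x, noise_std):
    """Global Wiener-style shrinkage applied to a whole (N, C, H, W) batch at
    once. Simplified sketch: the question's filter is blockwise, but the same
    idea expressed in batched tensor ops needs no per-image Python loop."""
    X = torch.fft.fft2(x)                      # 2-D FFT over the last two dims, per image
    power = X.real**2 + X.imag**2              # per-frequency signal power estimate
    n = noise_std**2 * x.shape[-1] * x.shape[-2]
    gain = power / (power + n)                 # empirical Wiener gain in [0, 1)
    return torch.fft.ifft2(gain * X).real

x = torch.randn(4, 1, 16, 16)
y = wiener_shrink_batched(x, 0.05)
print(y.shape)  # torch.Size([4, 1, 16, 16])
```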
