如何在 nn.Sequential 模型中使用自定义的 torch.autograd.Function



有没有什么方法可以在 nn.Sequential 对象中使用自定义的 torch.autograd.Function？还是说应该显式地使用带有 forward 方法的 nn.Module 对象？具体来说，我正在尝试实现一个稀疏自动编码器，需要把编码（隐藏表示）的 L1 距离添加到损失中。我在下面定义了自定义的 torch.autograd.Function L1Penalty，然后尝试在 nn.Sequential 对象中使用它，如下所示。然而，运行时出现了错误 TypeError: __main__.L1Penalty is not a Module subclass。我该如何解决这个问题？

class L1Penalty(torch.autograd.Function):
    """Identity-style op whose backward adds an L1 (sign) gradient to its input.

    NOTE(review): this is the question's original, non-working version; the
    problems are flagged inline below.
    """

    @staticmethod
    def forward(ctx, input, l1weight = 0.1):
        # Keep the input for backward, where its sign is needed.
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        # BUG: returns a (tensor, None) tuple rather than just `input`, so the
        # next layer would receive a tuple instead of a tensor.
        return input, None

    @staticmethod
    def backward(ctx, grad_output):
        # NOTE(review): forward returns two outputs, so autograd likely passes
        # two gradient arguments here; this signature accepts only one.
        # NOTE(review): `saved_variables` is deprecated; `ctx.saved_tensors`
        # is the supported name.
        input, = ctx.saved_variables
        # d/dx (l1weight * |x|) = l1weight * sign(x)
        grad_input = input.clone().sign().mul(ctx.l1weight)
        grad_input+=grad_output
        # NOTE(review): forward takes two inputs (input, l1weight); returning
        # (grad_input, None) would be needed when l1weight is passed explicitly.
        return grad_input
# The question's attempt. Per the question text, this raises
# "TypeError: __main__.L1Penalty is not a Module subclass": nn.Sequential only
# accepts nn.Module instances, and L1Penalty derives from
# torch.autograd.Function, not nn.Module.
# NOTE(review): `device` is not defined anywhere in this snippet.
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 6),
    nn.ReLU(),
    # sparsity
    L1Penalty(),
    nn.Linear(6, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU()
).to(device)

正确的做法如下：

import torch, torch.nn as nn
class L1Penalty(torch.autograd.Function):
    """Identity in the forward pass; adds an L1 sparsity gradient in backward.

    ``forward`` returns its input unchanged. ``backward`` adds
    ``l1weight * sign(input)`` to the incoming gradient, which is the
    (sub)gradient of ``l1weight * input.abs().sum()``. The effect is an L1
    penalty on the activation that passes through this op, without an explicit
    extra loss term.

    Call it as ``L1Penalty.apply(x)`` or ``L1Penalty.apply(x, l1weight)``.
    Autograd Functions only have static methods and are not instantiated.
    """

    @staticmethod
    def forward(ctx, input, l1weight=0.1):
        # Keep the input for backward, where its sign is needed.
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        return input

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors  # `saved_variables` is deprecated
        # sign() already allocates a new tensor, so no clone() is required.
        grad_input = input.sign().mul(ctx.l1weight)
        grad_input += grad_output
        # One gradient per forward input: l1weight is a plain number and gets
        # None. Autograd drops trailing None gradients when apply() was called
        # with the tensor only, so both call forms work.
        return grad_input, None

然后创建一个充当包装器（wrapper）的 Lambda 类，把任意函数包装成可以放进 nn.Sequential 的 nn.Module：

class Lambda(nn.Module):
    """Adapt a plain callable into an ``nn.Module`` for use in ``nn.Sequential``.

    Args:
        func: callable invoked on the module's single input in ``forward``
            (for example ``L1Penalty.apply``).
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        """Return ``func(x)``."""
        wrapped = self.func
        result = wrapped(x)
        return result

TA-DA!

# Encoder (10 -> 10 -> 6), the L1 penalty wrapped as a Module via Lambda,
# then decoder (6 -> 10 -> 10), all inside a single nn.Sequential.
layers = [
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 6),
    nn.ReLU(),
    # sparsity
    Lambda(L1Penalty.apply),
    nn.Linear(6, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
]
model = nn.Sequential(*layers)
a = torch.rand((50, 10))
b = model(a)
print(b.size())  # torch.Size([50, 10])

使用 nn.Module API 似乎可以正常工作，但不应在 L1Penalty 的 forward 方法中返回 None。

import torch, torch.nn as nn
class L1Penalty(torch.autograd.Function):
    """Identity in the forward pass; adds an L1 sparsity gradient in backward.

    ``forward`` returns its input unchanged. ``backward`` adds
    ``l1weight * sign(input)`` to the incoming gradient, which is the
    (sub)gradient of ``l1weight * input.abs().sum()``. The effect is an L1
    penalty on the activation that passes through this op, without an explicit
    extra loss term.

    Call it as ``L1Penalty.apply(x)`` or ``L1Penalty.apply(x, l1weight)``.
    Autograd Functions only have static methods and are not instantiated.
    """

    @staticmethod
    def forward(ctx, input, l1weight=0.1):
        # Keep the input for backward, where its sign is needed.
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        return input

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors  # `saved_variables` is deprecated
        # sign() already allocates a new tensor, so no clone() is required.
        grad_input = input.sign().mul(ctx.l1weight)
        grad_input += grad_output
        # One gradient per forward input: l1weight is a plain number and gets
        # None. Autograd drops trailing None gradients when apply() was called
        # with the tensor only, so both call forms work.
        return grad_input, None

class Model(nn.Module):
    """Autoencoder 10 -> 10 -> 6 -> 10 -> 10 with ReLU after every Linear.

    The 6-unit hidden code (``fc2`` + ReLU) is passed through
    ``L1Penalty.apply`` before decoding.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 6)
        self.fc3 = nn.Linear(6, 10)
        self.fc4 = nn.Linear(10, 10)
        # ReLU holds no parameters, so a single instance is reused everywhere.
        self.relu = nn.ReLU(inplace=True)
        # Store the Function *class*: autograd Functions only have static
        # methods, and instantiating one is deprecated (DeprecationWarning).
        # `self.penalty.apply(...)` works the same on the class.
        self.penalty = L1Penalty

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.penalty.apply(x)  # sparsity penalty on the hidden code
        x = self.relu(self.fc3(x))
        return self.relu(self.fc4(x))

model = Model()
a = torch.rand((50, 10))  # batch of 50 inputs with 10 features each
b = model(a)
print(b.size())  # torch.Size([50, 10])

相关内容

  • 没有找到相关文章

最新更新