有相当于tf.custom_gradient()的PyTorch吗



我是PyTorch的新手,但对TensorFlow有很多经验。

我想修改图中一小块的梯度:单层激活函数的导数。这可以在Tensorflow中使用tf.custom_gradient轻松完成,它允许您为任何函数提供自定义的梯度。

我想在PyTorch中做同样的事情,我知道你可以修改backward((方法,但这需要你重写forward((法中定义的整个网络的导数,而我只想修改图中一小块的梯度。PyTorch中有类似tf.custom_gradient((的东西吗?谢谢

您可以通过以下两种方式实现:

1.修改backward()函数:
正如您在问题中所说,pytorch还允许您提供自定义的backward实现。然而,与您所写的相反,您确实不需要重写整个模型的backward(),只需要重写要更改的特定层的backward()
这里有一个简单而漂亮的教程,展示了如何做到这一点。

例如,这里有一个自定义的clip激活,它不是杀死[0, 1]域之外的梯度,而是简单地按如下方式传递梯度:

class MyClip(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return torch.clip(x, 0., 1.)
@staticmethod
def backward(ctx, grad):
return grad

现在,您可以在模型中的任何位置使用MyClip层,并且不需要担心backward的整体功能。


2.使用backward挂钩pytorch可以将挂钩连接到网络的不同层(=子nn.Modules(。您可以将register_full_backward_hook添加到您的图层。挂钩功能可以修改梯度:

钩子不应该修改其参数,但它可以选择返回一个关于输入的新梯度,该梯度将在后续计算中代替grad_input使用。

相关内容

  • 没有找到相关文章

最新更新