如何在pytorch中返回中间梯度(对于非叶节点)



我的问题是关于pytorch register_hook的语法。

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y
x.register_hook(print)
y.register_hook(print)
z.backward()

输出:

tensor([2.])
tensor([4.])

此片段仅分别打印z w.r.t xy 的梯度。

现在我(很可能是微不足道的(问题是如何返回中间渐变(而不仅仅是打印(?

更新:

似乎调用retain_grad()解决了叶节点的问题。 y.retain_grad() .

但是,对于非叶节点,retain_grad似乎并不能解决它。有什么建议吗?

我认为您可以使用这些钩子将梯度存储在全局变量中:

grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y
x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))
z.backward()

但您很可能还需要记住计算这些梯度的相应张量。在这种情况下,我们使用dict而不是list稍微扩展上面:

grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y
def store(grad,parent):
    print(grad,parent)
    grads[parent] = grad.clone()
x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))
z.sum().backward()

例如,现在您可以简单地使用grads[y]访问张量y的 grad

最新更新