我的问题是关于pytorch register_hook
的语法。
x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y
x.register_hook(print)
y.register_hook(print)
z.backward()
输出:
tensor([2.])
tensor([4.])
此片段仅分别打印z
w.r.t x
和 y
的梯度。
现在我(很可能是微不足道的(问题是如何返回中间渐变(而不仅仅是打印(?
更新:
似乎调用retain_grad()
解决了叶节点的问题。 y.retain_grad()
.
但是,对于非叶节点,retain_grad
似乎并不能解决它。有什么建议吗?
我认为您可以使用这些钩子将梯度存储在全局变量中:
grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y
x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))
z.backward()
但您很可能还需要记住计算这些梯度的相应张量。在这种情况下,我们使用dict
而不是list
稍微扩展上面:
grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y
def store(grad,parent):
print(grad,parent)
grads[parent] = grad.clone()
x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))
z.sum().backward()
例如,现在您可以简单地使用grads[y]
访问张量y
的 grad