RuntimeError:试图第二次向后遍历图形,但pytorch已经释放了缓冲区



如何在第二次调用.backward()之前清除渐变。

RuntimeError:试图第二次向后遍历图形,但已释放保存的中间结果。第一次向后调用时指定retain_graph=True

a = torch.tensor([2.0], requires_grad = True)
b = torch.tensor([2.0], requires_grad = True)
d = torch.tensor([2.0], requires_grad = True)
c=a*b
c.backward()
e = d*e
e.backward(retain_graph=True)

我试着这样做:c.zero_grad(),但我得到了错误c没有方法zero_grad()

当错误消息显示时,您需要在第一个.backward调用上指定retain_graph=True选项,而不是第二个:

c.backward(retain_graph=True)
e = d*c
e.backward()

如果不保留图形,则第二次反向传递将无法到达节点cab,因为第一次反向传递已清除激活。

最新更新