假设我有一个输入X和一个由网络a,网络B和网络C组成的连续网络,如果我分离网络B,让X通过a ->B->C,因为B是分离的,我是否会失去a的梯度信息?我想没有吧?我假设它只是把B当作一个常数加到a的输出上,而不是一个可微的东西。
TLDR;阻止B
的梯度计算并不会阻止上游网络A
的梯度计算。
我认为你对"分离模型"的理解有些混乱。在我看来,有三件事要记住这类事情:
-
你可以
detach
一个张量,它有效地从计算图中分离出来,也就是说,如果这个张量被用来计算另一个需要梯度的张量,反向传播步骤将不会传播到这个"分离的"张量;张量。 -
在您描述"分离模型"的方式中,您可以通过在其参数上将
requires_grad
切换为False
来禁用网络给定层上的梯度计算。这可以用nn.Module.requires_grad_
在模块级的一行中完成。因此,在您的情况下,执行B.requires_grad_(False)
将冻结B
的参数,使它们无法更新。换句话说,B
的参数梯度不会被计算,但是用于传播到A
的中间梯度将!下面是一个简单的例子:
我们现在可以检查C和的梯度A确实被正确填充了:>>> A = nn.Linear(10,10) >>> B = nn.Linear(10,10) >>> C = nn.Linear(10,10) # disable gradient computation on B >>> B.requires_grad_(False) # dummy input, inference, and backpropagation >>> x = torch.rand(1,10, requires_grad=True) >>> C(B(A(x))).mean().backward()
>>> A.weight.grad.sum() tensor(0.3281) >>> C.weight.grad.sum() tensor(-1.6335)
当然,
B.weight.grad
返回None
。 -
最后,另一种行为是在使用
no_grad
上下文管理器时。这有效地消除了渐变。如果你这样做:>>> yA = A(x) >>> with torch.no_grad(): ... yB = B(yA) >>> yC = C(yB)
此处
yC
已脱离网络。