Pytorch -如果你分离一个nn.网络中间的模块之前的所有模块都没有计算出它们的梯度吗?



假设我有一个输入X和一个由网络a,网络B和网络C组成的连续网络,如果我分离网络B,让X通过a ->B->C,因为B是分离的,我是否会失去a的梯度信息?我想没有吧?我假设它只是把B当作一个常数加到a的输出上,而不是一个可微的东西。

TLDR;阻止B的梯度计算并不会阻止上游网络A的梯度计算。


我认为你对"分离模型"的理解有些混乱。在我看来,有三件事要记住这类事情:

  • 你可以detach一个张量,它有效地从计算图中分离出来,也就是说,如果这个张量被用来计算另一个需要梯度的张量,反向传播步骤将不会传播到这个"分离的"张量;张量。

  • 在您描述"分离模型"的方式中,您可以通过在其参数上将requires_grad切换为False来禁用网络给定层上的梯度计算。这可以用nn.Module.requires_grad_在模块级的一行中完成。因此,在您的情况下,执行B.requires_grad_(False)将冻结B的参数,使它们无法更新。换句话说,B的参数梯度不会被计算,但是用于传播到A的中间梯度!下面是一个简单的例子:

    >>> A = nn.Linear(10,10)
    >>> B = nn.Linear(10,10)
    >>> C = nn.Linear(10,10)
    # disable gradient computation on B
    >>> B.requires_grad_(False)
    # dummy input, inference, and backpropagation
    >>> x = torch.rand(1,10, requires_grad=True)
    >>> C(B(A(x))).mean().backward()
    
    我们现在可以检查C的梯度A确实被正确填充了:
    >>> A.weight.grad.sum()
    tensor(0.3281)
    >>> C.weight.grad.sum()
    tensor(-1.6335)
    

    当然,B.weight.grad返回None

  • 最后,另一种行为是在使用no_grad上下文管理器时。这有效地消除了渐变。如果你这样做:

    >>> yA = A(x)
    >>> with torch.no_grad():
    ...    yB = B(yA)
    >>> yC = C(yB)
    

    此处yC已脱离网络。

最新更新