PyTorch's torch.autograd.grad 中grad_outputs的含义



我无法理解torch.autograd.gradgrad_outputs选项的概念含义。

文档说:

grad_outputs应该是包含雅可比向量积中的"向量"的长度匹配输出序列,通常是每个输出中预先计算的梯度。如果输出不require_grad,则梯度可以None)。

我觉得这个描述很晦涩难懂。雅可比向量积到底是什么意思?我知道雅可比是什么,但不确定它们在这里是什么意思:元素方面,矩阵积,别的什么?我无法从下面的例子中分辨出来。

为什么"向量">在引号中?事实上,在下面的示例中,当grad_outputs是向量时,我会收到错误,但当它是矩阵时,我不会出错。

>>> x = torch.tensor([1.,2.,3.,4.], requires_grad=True)
>>> y = torch.outer(x, x)

为什么我们观察到以下输出;它是如何计算的?

>>> y
tensor([[ 1.,  2.,  3.,  4.],
[ 2.,  4.,  6.,  8.],
[ 3.,  6.,  9., 12.],
[ 4.,  8., 12., 16.]], grad_fn=<MulBackward0>)
>>> torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
(tensor([20., 20., 20., 20.]),)

但是,为什么会出现此错误?

>>> torch.autograd.grad(y, x, grad_outputs=torch.ones_like(x))  

运行时错误:形状不匹配:grad_output[0]的形状为torch.Size([4])output[0]的形状为torch.Size([4, 4])

如果我们以你为例,我们有函数f,它x形状(n,)作为输入,输出y = f(x)形状(n, n)。输入被描述为列向量[x_i]_i for i ∈ [1, n]f(x)被定义为矩阵[y_jk]_jk = [x_j*x_k]_jk for j, k ∈ [1, n]²

计算输出相对于输入的梯度通常很有用(或者有时 w.r.tf的参数,这里没有)。在更一般的情况下,我们希望计算dL/dx而不仅仅是dy/dx,其中dL/dxL的偏导数,y计算,w.r.t。x.

计算图如下所示:

x.grad = dL/dx <-------   dL/dy y.grad
dy/dx
x       ------->    y = x*xT

然后,如果我们看一下dL/dx,即通过链式规则等于dL/dy*dy/dx.查看torch.autograd.grad的界面,我们有以下对应关系:

  • outputs<->y
  • inputs<->x,以及
  • grad_outputs<->dL/dy.

看形状:dL/dx应该与x具有相同的形状(dL/dx可以称为x的"梯度"),而雅可比矩阵dy/dx将是三维的。另一方面,dL/dy,即输入的梯度,应该与输出具有相同的形状,y的形状。

我们要计算dL/dx = dL/dy*dy/dx.如果我们更仔细地观察,我们有

dy/dx = [dy_jk/dx_i]_ijk for i, j, k ∈ [1, n]³

因此

dL/dx = [dL/d_x_i]_i, i ∈ [1,n]
= [sum(dL/dy_jk * d(y_jk)/dx_i over j, k ∈ [1, n]²]_i, i ∈ [1,n]

回到您的示例,这意味着对于给定的i ∈ [1, n]dL/dx_i = sum(dy_jk/dx_i) over j, k ∈ [1,n]²。如果i = kdy_jk/dx_i = f(x_j*x_k)/dx_i等于x_j,如果i = jx_k等于2*x_i如果i = j = k(因为平方x_i)。话虽如此,矩阵y是对称的...所以结果归结为2*sum(x_i) over i ∈ [1, n]

这意味着列向量dL/dx[2*sum(x)]_i for i ∈ [1, n]

>>> 2*x.sum()*torch.ones_like(x)
tensor([20., 20., 20., 20.])

退一步看看这个其他图形示例,这里在y之后添加一个额外的操作:

x   ------->  y = x*xT  -------->  z = y²

如果您查看此图上的向后传递,您将得到:

dL/dx <-------   dL/dy    <--------  dL/dz
dy/dx              dz/dy 
x   ------->  y = x*xT  -------->  z = y²

dL/dx = dL/dy*dy/dx = dL/dz*dz/dy*dy/dx实际上分两个顺序步骤计算:dL/dy = dL/dz*dz/dy,然后dL/dx = dL/dy*dy/dx

最新更新