我无法理解torch.autograd.grad
中grad_outputs
选项的概念含义。
文档说:
grad_outputs
应该是包含雅可比向量积中的"向量"的长度匹配输出序列,通常是每个输出中预先计算的梯度。如果输出不require_grad
,则梯度可以None
)。
我觉得这个描述很晦涩难懂。雅可比向量积到底是什么意思?我知道雅可比是什么,但不确定它们在这里是什么意思:元素方面,矩阵积,别的什么?我无法从下面的例子中分辨出来。
为什么"向量">在引号中?事实上,在下面的示例中,当grad_outputs
是向量时,我会收到错误,但当它是矩阵时,我不会出错。
>>> x = torch.tensor([1.,2.,3.,4.], requires_grad=True)
>>> y = torch.outer(x, x)
为什么我们观察到以下输出;它是如何计算的?
>>> y
tensor([[ 1., 2., 3., 4.],
[ 2., 4., 6., 8.],
[ 3., 6., 9., 12.],
[ 4., 8., 12., 16.]], grad_fn=<MulBackward0>)
>>> torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
(tensor([20., 20., 20., 20.]),)
但是,为什么会出现此错误?
>>> torch.autograd.grad(y, x, grad_outputs=torch.ones_like(x))
运行时错误:形状不匹配:
grad_output[0]
的形状为torch.Size([4])
,output[0]
的形状为torch.Size([4, 4])
。
如果我们以你为例,我们有函数f
,它x
形状(n,)
作为输入,输出y = f(x)
形状(n, n)
。输入被描述为列向量[x_i]_i for i ∈ [1, n]
,f(x)
被定义为矩阵[y_jk]_jk = [x_j*x_k]_jk for j, k ∈ [1, n]²
。
计算输出相对于输入的梯度通常很有用(或者有时 w.r.tf
的参数,这里没有)。在更一般的情况下,我们希望计算dL/dx
而不仅仅是dy/dx
,其中dL/dx
是L
的偏导数,从y
计算,w.r.t。x
.
计算图如下所示:
x.grad = dL/dx <------- dL/dy y.grad
dy/dx
x -------> y = x*xT
然后,如果我们看一下dL/dx
,即通过链式规则等于dL/dy*dy/dx
.查看torch.autograd.grad
的界面,我们有以下对应关系:
outputs
<->y
,inputs
<->x
,以及grad_outputs
<->dL/dy
.
看形状:dL/dx
应该与x
具有相同的形状(dL/dx
可以称为x
的"梯度"),而雅可比矩阵dy/dx
将是三维的。另一方面,dL/dy
,即输入的梯度,应该与输出具有相同的形状,即y
的形状。
我们要计算dL/dx = dL/dy*dy/dx
.如果我们更仔细地观察,我们有
dy/dx = [dy_jk/dx_i]_ijk for i, j, k ∈ [1, n]³
因此
dL/dx = [dL/d_x_i]_i, i ∈ [1,n]
= [sum(dL/dy_jk * d(y_jk)/dx_i over j, k ∈ [1, n]²]_i, i ∈ [1,n]
回到您的示例,这意味着对于给定的i ∈ [1, n]
:dL/dx_i = sum(dy_jk/dx_i) over j, k ∈ [1,n]²
。如果i = k
,dy_jk/dx_i = f(x_j*x_k)/dx_i
等于x_j
,如果i = j
,x_k
等于2*x_i
如果i = j = k
(因为平方x_i
)。话虽如此,矩阵y
是对称的...所以结果归结为2*sum(x_i) over i ∈ [1, n]
这意味着列向量dL/dx
[2*sum(x)]_i for i ∈ [1, n]
。
>>> 2*x.sum()*torch.ones_like(x)
tensor([20., 20., 20., 20.])
退一步看看这个其他图形示例,这里在y
之后添加一个额外的操作:
x -------> y = x*xT --------> z = y²
如果您查看此图上的向后传递,您将得到:
dL/dx <------- dL/dy <-------- dL/dz
dy/dx dz/dy
x -------> y = x*xT --------> z = y²
dL/dx = dL/dy*dy/dx = dL/dz*dz/dy*dy/dx
实际上分两个顺序步骤计算:dL/dy = dL/dz*dz/dy
,然后dL/dx = dL/dy*dy/dx
。