PyTorch中标量展开的导数



我正致力于在Rust中实现一个非常简单的自动diff库,以扩展我对如何完成它的了解。我几乎所有的东西都在工作,但是当实现负对数似然时,我意识到我对如何处理以下场景的导数有一些困惑(我已经在下面的PyTorch中编写了它)。

x = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grads=True)
y = x - torch.sum(x)

我四处看了看,做了实验,但我对这里到底发生了什么还是有点困惑。我知道上面的方程关于x的导数是[-2,-2,-2],但是有很多方法可以得到它,当我把方程展开成下面的式子时

x = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grads=True)
y = torch.exp(x - torch.sum(x))

我完全迷路了,不知道它是如何推导出x的梯度的。

我假设上面的方程被改写成这样:

y = (x - [torch.sum(x), torch.sum(x), torch.sum(x)])

,但我不确定,我真的很难找到关于标量被扩展到向量或任何实际发生的话题的信息。如果有人能给我指出正确的方向,那就太棒了!

如果有帮助的话,我可以包括上面方程的梯度pytorch计算。

你的代码在没有任何修改的情况下不能与PyTorch一起工作,因为它没有指定w.r.t到y的梯度是什么。你需要它们调用y.backward()它计算从t到x的梯度。从你的全-2的结果,我认为梯度一定是全1。

"标量展开"叫做广播。您已经知道,只要两个张量操作数的形状不匹配,就会执行广播。我的猜测是,它的实现方式与PyTorch中的任何其他操作相同,该操作知道如何计算其输入的梯度,给定其输出的梯度。下面给出了一个简单的例子,(A)与给定的测试用例一起工作,(b)允许我们仍然使用PyTorch的autograd自动计算梯度(另请参阅PyTorch关于扩展autograd的文档):

class Broadcast(torch.autograd.Function):
def forward(ctx, x: torch.Tensor, length: int) -> torch.Tensor:
assert x.ndim == 0, "input must be a scalar tensor"
assert length > 0, "length must be greater than zero"
return x.new_full((length,), x.item())
def backward(ctx, grad: torch.Tensor) -> Tuple[torch.Tensor, None]:
return grad.sum(), None

现在,通过设置broadcast = Broadcast.apply,我们可以自己调用广播,而不是让PyTorch自动执行。

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x - broadcast(torch.sum(x), x.size(0))
y.backward(torch.ones_like(y))
assert torch.allclose(torch.tensor(-2.), x.grad)

注意,我不知道PyTorch实际上是如何实现它的。上面的实现只是为了说明如何编写广播操作来实现自动区分,希望它能回答您的问题。

首先,一些事情,参数是requires_grad而不是require_grads。其次,只能对浮点型或复杂dtype要求梯度。

现在,一个标量的加法/乘法(注意,减法/除法可以被看作是一个- 5数的加法/一个分数的乘法)只是将标量与张量的所有元素相加/相乘。因此,

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x - 1

评估:

y = tensor([-1.,  0.,  1.], grad_fn=<SubBackward0>)

因此,在你的例子中,torch.sum(x)基本上是一个标量,它从张量x的所有元素中减去。

如果你对渐变部分更感兴趣,请查看autograd [ref]上的pytorch文档。它声明如下:

用链式法则求导图。如果任何张量是非标量的(即它们的数据有多个元素)并且需要梯度,那么将计算雅可比向量积,在这种情况下,函数额外需要指定grad_tensors。它应该是一个匹配长度的序列,包含雅可比向量积中的"向量",通常是微分函数w.r.t.对应张量的梯度(None是所有不需要梯度张量的张量的可接受值)。

相关内容

  • 没有找到相关文章

最新更新