if-else语句和torch之间有什么区别?pytorch中的where



请参阅代码片段:

import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)

输出为tensor([0.]),但

import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
y = x
else:
y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)

输出为None

我很困惑为什么torch.where的输出是tensor([0.])

更新

import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b
(a[0, 0] * a[0, 1]).backward()
print(b.grad)

输出为tensor([2., 0.])(a[0, 0] * a[0, 1])b[1]没有任何关系,但b[1]的梯度是0而不是None

基于跟踪的AD与pytorch一样,通过跟踪来工作。您无法跟踪库截获的非函数调用的内容。通过使用这样的if语句,xy之间没有连接,而对于wherexy在表达式树中是链接的。

现在,对于差异:

  • 在第一个片段中,0是函数x ↦ x > 0 ? x : 2在点-1的正确导数(因为负侧是常数(
  • 在第二个片段中,正如我所说,xy(在else分支中(没有任何关系。因此,给定xy的导数是未定义的,其表示为None

(即使在Python中也可以做这样的事情,但这需要更复杂的技术,比如源代码转换。我认为pytorch不可能做到这一点。(

相关内容

  • 没有找到相关文章

最新更新