是否有一个很好的方法来修改pytorch张量中的一些值,同时保留autograd功能?



有时我需要修改pytorch张量中的一些值。例如,给定一个张量x,我需要将它的正部分乘以2,负部分乘以3:

import torch
x = torch.randn(1000, requires_grad=True)
x[x>0] = 2 * x[x>0]
x[x<0] = 3 * x[x<0]
y = x.sum()
y.backward()

但是这样的本地操作总是破坏autograd的图形:

Traceback (most recent call last):
File "test_rep.py", line 4, in <module>
x[x>0] = 2 * x[x>0]
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
因此,到目前为止,我一直在使用以下解决方案:
import torch
x = torch.randn(1000, requires_grad=True)
y = torch.zeros_like(x, device=x.device)
y[x>0] = 2 * x[x>0]
y[x<0] = 3 * x[x<0]
z = y.sum()
z.backward()

导致手动创建新的张量。我想知道是否有更好的方法来做这件事。

像下面这样怎么样?

import torch
x = torch.randn(1000, requires_grad=True)
x = torch.where(x>0, x*2, x)
x = torch.where(x<0, x*3, x)
y = x.sum()
y.backward()

有一个更好的方法,至少对于这个特定的情况,基于LeakyReLU的工作方式:

import torch
x = torch.randn(10, requires_grad=True)
y = 2 * torch.max(x, torch.tensor(0.0)) + 3 * torch.min(x, torch.tensor(0.0))

一些广播将用于0-dimtorch.tensor(0.0)

PyTorch不能很好地处理就地操作,因为张量被记录在磁带上(以及创建它们的操作)。如果你覆盖其中的一些值(这里是x>0x<0的张量),它们的历史将被覆盖而不是追加(就像在张量上应用不合适的操作一样,如上所述)。

很少有使用原地操作的情况(除非你真的受到内存使用的限制)。

最新更新