pytorch中的张量切片导致torch分配不到位



我正在尝试将当前代码转换为外部操作,该代码在适当的位置分配张量
表示当前代码为

self.X[:, nc:] = D

其中D与self.X[:, nc:]形状相同
但我想将其转换为

sliced_index = ~ somehow create an indexed tensor from self.X[:, nc:]
self.X = self.X.scatter(1,sliced_index,mm(S_, Z[:, :n - nc]))

不知道如何创建只表示切片张量中的条目的索引掩码张量

最小示例:

a = [[0,1,2],[3,4,5]]
D = [[6],[7]]
Not_in_place = [[0,1,6],[3,4,7]]

掩蔽散射稍微容易一些。掩码本身可以作为就地操作进行计算,之后可以使用masked_scatter

mask = torch.zeros(self.X.shape, device=self.X.device, dtype=torch.bool)
mask[:, nc:] = True
self.X = self.X.masked_scatter(mask, D)

是一个更专业的版本,它依赖于广播,但应该更高效

mask = torch.zeros([1, self.X.size(1)], device=self.X.device, dtype=torch.bool)
mask[0, nc:] = True
self.X = self.X.masked_scatter(mask, D)

使用Tensor.clone复制张量。

a = torch.tensor([[0,1,2],[3,4,5]])
D = torch.tensor([[6],[7]])
n, n[:,-1:] = a.clone(), D
n
tensor([[0, 1, 6],
[3, 4, 7]])
a
tensor([[0, 1, 2],
[3, 4, 5]])

最新更新