PyTorch用向量切片矩阵



假设我有一个矩阵和一个向量,如下所示:

import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
y = torch.tensor([0, 2, 1])

有没有办法把它分割成x[y],所以结果是:

res = [1, 6, 8]

所以基本上我取y的第一个元素,取x中对应于第一行和元素列的元素。

您可以将相应的行索引指定为:

import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
y = torch.tensor([0, 2, 1])
x[range(x.shape[0]), y]
tensor([1, 6, 8])

pytorch中的高级索引与NumPy's一样工作,即索引数组在轴上一起广播。所以你可以按照FBruzzesi的回答去做。

虽然与np.take_along_axis类似,但在pytorch中,您也有torch.gather,可以沿特定轴获取值:

x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])

最新更新