假设我有一个矩阵和一个向量,如下所示:
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
y = torch.tensor([0, 2, 1])
有没有办法把它分割成x[y]
,所以结果是:
res = [1, 6, 8]
所以基本上我取y
的第一个元素,取x
中对应于第一行和元素列的元素。
您可以将相应的行索引指定为:
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
y = torch.tensor([0, 2, 1])
x[range(x.shape[0]), y]
tensor([1, 6, 8])
pytorch中的高级索引与NumPy's
一样工作,即索引数组在轴上一起广播。所以你可以按照FBruzzesi的回答去做。
虽然与np.take_along_axis
类似,但在pytorch中,您也有torch.gather
,可以沿特定轴获取值:
x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])