我有一个形状为(b,n)
的值张量val
和形状为(b,m)
的指标张量ind
(其中n>m
)。我的目标是取val
中对应于ind
中的索引的值。我试过使用val[ind]
,但它只扩展了val
的维度,而不是只取相关的项目
val = torch.tensor([[1,2,3],
[4,5,6],
[7,8,9],
[10,11,12],
[13,14,15]])
ind = torch.tensor([[1,2],
[0,2],
[0,1],
[1,2],
[0,1]])
val[ind] # shaped (5,2,4), I need (5,2)
需要的输出是
torch.tensor([[2,3],
[4,6],
[7,8],
[11,12],
[13,14]])
您可以使用torch.gather
:
>>> val.gather(dim=1, index=ind)
tensor([[ 2, 3],
[ 4, 6],
[ 7, 8],
[11, 12],
[13, 14]])
本质上是用ind
的值索引val
的第二维。返回的张量out
如下:
out[i][j] = val[i][ind[i]]