提取给定张量索引值的张量值



我有一个形状为(b,n)的值张量val和形状为(b,m)的指标张量ind(其中n>m)。我的目标是取val中对应于ind中的索引的值。我试过使用val[ind],但它只扩展了val的维度,而不是只取相关的项目

val = torch.tensor([[1,2,3],
[4,5,6],
[7,8,9],
[10,11,12],
[13,14,15]])   
ind = torch.tensor([[1,2],
[0,2],
[0,1],
[1,2],
[0,1]])
val[ind] # shaped (5,2,4), I need (5,2)

需要的输出是

torch.tensor([[2,3],
[4,6],
[7,8],
[11,12],
[13,14]])

您可以使用torch.gather:

执行此操作。
>>> val.gather(dim=1, index=ind)
tensor([[ 2,  3],
[ 4,  6],
[ 7,  8],
[11, 12],
[13, 14]])

本质上是用ind的值索引val的第二维。返回的张量out如下:

out[i][j] = val[i][ind[i]]

相关内容

  • 没有找到相关文章

最新更新