在pytorch中寻找用另一个2D张量索引2D张量的有效方法



我有一个张量,比如说

A = tensor([
[0, 0],
[0, 2],
[0, 3],
[0, 4],
[0, 5],
[0, 6],
[1, 0],
[1, 1],
[1, 4],
[1, 5],
[1, 6]
])

和另一个张量

b = tensor([[0, 2], [1, 2]])

我想找到一种有效的方法,通过b索引到A,使结果是

result = tensor([[0, 3], [1, 4]])

即,将A的最后一个dim的第一列(即[0,…,1…])与b的最后一个dim的第一列(即[0,1])的值进行匹配,然后使用b的第二列(即[2,2])索引A的第二列。

感谢

用火炬将其转化为一维问题求解。非零且被掩码和偏移。

代替原来的A,得到一个平坦的版本,如

A = tensor([[ 0], [ 2], [ 3], [ 4], [ 5], [ 7], [ 8], [11], [12]])

并计算沿批处理的偏移量,

offset = tensor([[0], [5], [4]])
同样地,得到b
b = tensor([2, 2])

offset_b = b+offset.reshape(-1)[:-1]

然后

indices=A.reshape(-1)[offset_b]

相关内容

  • 没有找到相关文章

最新更新