How to sort a 3D tensor by the coordinates in the last dimension (PyTorch)



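I have a tensor of shape [bn, k, 2]. The last dimension holds coordinates, and I want each batch to be sorted independently by the y coordinate ([:, :, 0]). My approach looks like this: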

import torch
a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]
a_sorted = a[:, indices]
print(a)
print(a_sorted)

So far so good, but it now sorts both batches according to both index lists, so I end up with 4 batches in total:

a
tensor([[[ 0.5160,  0.3257],
[-1.2410, -0.8361],
[ 1.3826, -1.1308],
[ 0.0338,  0.1665],
[-0.9375, -0.3081]],
[[ 0.4140, -1.0962],
[ 0.9847, -0.7231],
[-0.0110,  0.6437],
[-0.4914,  0.2473],
[-0.0938, -0.0722]]])
a_sorted
tensor([[[[-1.2410, -0.8361],
[-0.9375, -0.3081],
[ 0.0338,  0.1665],
[ 0.5160,  0.3257],
[ 1.3826, -1.1308]],
[[ 0.0338,  0.1665],
[-0.9375, -0.3081],
[ 1.3826, -1.1308],
[ 0.5160,  0.3257],
[-1.2410, -0.8361]]],

[[[ 0.9847, -0.7231],
[-0.0938, -0.0722],
[-0.4914,  0.2473],
[ 0.4140, -1.0962],
[-0.0110,  0.6437]],
[[-0.4914,  0.2473],
[-0.0938, -0.0722],
[-0.0110,  0.6437],
[ 0.4140, -1.0962],
[ 0.9847, -0.7231]]]])

As you can see, I only want the first and the fourth batch returned. How can I do that?

What you want: the concatenation of a[0, indices[0]] and a[1, indices[1]].

What you coded: the concatenation of a[0, indices] and a[1, indices].
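The difference shows up directly in the shapes. Here is a minimal sketch, assuming the same [2, 5, 2] input as in the question; the names coded and wanted are only for illustration:

```python
import torch

a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]   # shape [2, 5]: one permutation per batch

# What was coded: every batch is reordered by every permutation.
coded = a[:, indices]            # shape [2, 2, 5, 2]

# What is wanted: batch b reordered only by its own permutation.
wanted = torch.stack([a[b, indices[b]] for b in range(a.shape[0])])  # shape [2, 5, 2]

print(coded.shape, wanted.shape)

# The wanted batches sit on the "diagonal" of the coded result,
# i.e. the first and the fourth of the four printed blocks.
print(torch.equal(coded[0, 0], wanted[0]))
print(torch.equal(coded[1, 1], wanted[1]))
```

In other words, coded[i, j] is a[i, indices[j]], and only the entries with i == j are the ones being asked for.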

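The problem you are facing is that the indices returned by sort have a shape that matches the first dimension, but they are only indices into the second dimension. When you use them, you want indices[0] to be matched with a[0], but PyTorch does not do this implicitly (fancy indexing is very powerful, and that power requires this syntax). So all you have to do is supply a parallel list of indices for the first dimension.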

For example, you want to use something like a[[[0], [1]], indices].
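This works because the index tensors are broadcast against each other before indexing: a [2, 1] batch index and a [2, 5] position index broadcast to [2, 5], so element (b, j) of the result should be a[b, indices[b, j]]. A small sketch to check that reading, with illustrative variable names that are not from the question:

```python
import torch

a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]         # shape [2, 5]
batch_idx = torch.tensor([[0], [1]])   # shape [2, 1], broadcasts against indices

result = a[batch_idx, indices]         # shape [2, 5, 2]

# Check the broadcast interpretation element by element.
for b in range(2):
    for j in range(5):
        assert torch.equal(result[b, j], a[b, indices[b, j]])

print(result.shape)
```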

To generalize this further, you can use the following:

n = a.shape[0]
first_indices = torch.arange(n)[:, None]
a[first_indices, indices]

This is a bit of a trick, so here is an example:

>>> a = torch.randn(2,4,2)
>>> a
tensor([[[-0.2050, -0.1651],
[ 0.5688,  1.0082],
[-1.5964, -0.9236],
[ 0.3093, -0.2445]],
[[ 1.0586,  1.0048],
[ 0.0893,  2.4522],
[ 2.1433, -1.2428],
[ 0.1591,  2.4945]]])
>>> indices = a[:, :, 0].sort()[1]
>>> indices
tensor([[2, 0, 3, 1],
[1, 3, 0, 2]])
>>> a[:, indices]
tensor([[[[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688,  1.0082]],
[[ 0.5688,  1.0082],
[ 0.3093, -0.2445],
[-0.2050, -0.1651],
[-1.5964, -0.9236]]],

[[[ 2.1433, -1.2428],
[ 1.0586,  1.0048],
[ 0.1591,  2.4945],
[ 0.0893,  2.4522]],
[[ 0.0893,  2.4522],
[ 0.1591,  2.4945],
[ 1.0586,  1.0048],
[ 2.1433, -1.2428]]]])
>>> a[0, indices]
tensor([[[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688,  1.0082]],
[[ 0.5688,  1.0082],
[ 0.3093, -0.2445],
[-0.2050, -0.1651],
[-1.5964, -0.9236]]])
>>> a[1, indices]
tensor([[[ 2.1433, -1.2428],
[ 1.0586,  1.0048],
[ 0.1591,  2.4945],
[ 0.0893,  2.4522]],
[[ 0.0893,  2.4522],
[ 0.1591,  2.4945],
[ 1.0586,  1.0048],
[ 2.1433, -1.2428]]])
>>> a[0, indices[0]]
tensor([[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688,  1.0082]])
>>> a[1, indices[1]]
tensor([[ 0.0893,  2.4522],
[ 0.1591,  2.4945],
[ 1.0586,  1.0048],
[ 2.1433, -1.2428]])
>>> a[[[0], [1]], indices]
tensor([[[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688,  1.0082]],
[[ 0.0893,  2.4522],
[ 0.1591,  2.4945],
[ 1.0586,  1.0048],
[ 2.1433, -1.2428]]])
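For reuse, the trick can be wrapped in a small helper. The sketch below is one possible packaging, not part of the original answer: the function names are hypothetical, and the torch.gather variant is an alternative technique, included only as a cross-check that should give the same result.

```python
import torch

def sort_by_first_coord(a: torch.Tensor) -> torch.Tensor:
    """Sort each batch of an [n, k, d] tensor by a[:, :, 0] (hypothetical helper)."""
    indices = a[:, :, 0].argsort(dim=1)                             # [n, k]
    batch_idx = torch.arange(a.shape[0], device=a.device)[:, None]  # [n, 1]
    return a[batch_idx, indices]                                    # [n, k, d]

def sort_by_first_coord_gather(a: torch.Tensor) -> torch.Tensor:
    """Same result via torch.gather (alternative technique, also hypothetical)."""
    indices = a[:, :, 0].argsort(dim=1)                             # [n, k]
    # gather needs an index with as many dimensions as the input: [n, k, d]
    expanded = indices.unsqueeze(-1).expand(-1, -1, a.shape[-1])
    return torch.gather(a, 1, expanded)

a = torch.randn(2, 4, 2)
out = sort_by_first_coord(a)
print(out)
print(torch.equal(out, sort_by_first_coord_gather(a)))
```

The indexing version stays closest to the answer above; the gather version avoids building the batch index, at the cost of expanding the index to the full output shape.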
