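I have a tensor of shape [bn, k, 2]. The last dimension holds coordinates, and I want each batch to be sorted independently by the y coordinate ([:, :, 0]). My approach looks like this: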
import torch
a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]
a_sorted = a[:, indices]
print(a)
print(a_sorted)
So far so good, but it now sorts both batches by both index lists, so I end up with 4 batches in total:
a
tensor([[[ 0.5160, 0.3257],
[-1.2410, -0.8361],
[ 1.3826, -1.1308],
[ 0.0338, 0.1665],
[-0.9375, -0.3081]],
[[ 0.4140, -1.0962],
[ 0.9847, -0.7231],
[-0.0110, 0.6437],
[-0.4914, 0.2473],
[-0.0938, -0.0722]]])
a_sorted
tensor([[[[-1.2410, -0.8361],
[-0.9375, -0.3081],
[ 0.0338, 0.1665],
[ 0.5160, 0.3257],
[ 1.3826, -1.1308]],
[[ 0.0338, 0.1665],
[-0.9375, -0.3081],
[ 1.3826, -1.1308],
[ 0.5160, 0.3257],
[-1.2410, -0.8361]]],
[[[ 0.9847, -0.7231],
[-0.0938, -0.0722],
[-0.4914, 0.2473],
[ 0.4140, -1.0962],
[-0.0110, 0.6437]],
[[-0.4914, 0.2473],
[-0.0938, -0.0722],
[-0.0110, 0.6437],
[ 0.4140, -1.0962],
[ 0.9847, -0.7231]]]])
As you can see, I only want the first and the fourth batch returned. How can I do that?
What you want: the concatenation of a[0, indices[0]] and a[1, indices[1]].
What you coded: the concatenation of a[0, indices] and a[1, indices].
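The problem you are facing is that the indices returned by sort have shape [bn, k], one row per batch, but their values are only positions along the second dimension. When you use them, you want indices[0] to be applied to a[0] and indices[1] to a[1], but PyTorch does not do that pairing implicitly (fancy indexing is very powerful, and that power requires this explicit syntax). So all you have to do is supply a parallel list of indices for the first dimension.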
For example, you want to use something like a[[[0], [1]], indices].
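To see why the extra column of batch indices matters, it may help to compare the shapes. The sketch below uses a fresh random tensor rather than the one from the question, and batch_idx is a hypothetical name for the [[0], [1]] column. The two index tensors broadcast to a common shape [bn, k], so the result at [i, j] is a[i, indices[i, j]].

```python
import torch

a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]          # shape [2, 5]

# Indexing only dim 1 keeps the batch dim and inserts the full index shape,
# so every batch gets reordered by every row of indices.
print(a[:, indices].shape)              # torch.Size([2, 2, 5, 2])

# A [2, 1] column of batch indices broadcasts against the [2, 5] indices,
# pairing row i of indices with batch i only.
batch_idx = torch.tensor([[0], [1]])    # hypothetical name, shape [2, 1]
a_sorted = a[batch_idx, indices]
print(a_sorted.shape)                   # torch.Size([2, 5, 2])
```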
To generalize this to any batch size, you can use the following:
n = a.shape[0]
first_indices = torch.arange(n)[:, None]
a[first_indices, indices]
It is a bit of a trick, so here is an example:
>>> a = torch.randn(2,4,2)
>>> a
tensor([[[-0.2050, -0.1651],
[ 0.5688, 1.0082],
[-1.5964, -0.9236],
[ 0.3093, -0.2445]],
[[ 1.0586, 1.0048],
[ 0.0893, 2.4522],
[ 2.1433, -1.2428],
[ 0.1591, 2.4945]]])
>>> indices = a[:, :, 0].sort()[1]
>>> indices
tensor([[2, 0, 3, 1],
[1, 3, 0, 2]])
>>> a[:, indices]
tensor([[[[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688, 1.0082]],
[[ 0.5688, 1.0082],
[ 0.3093, -0.2445],
[-0.2050, -0.1651],
[-1.5964, -0.9236]]],
[[[ 2.1433, -1.2428],
[ 1.0586, 1.0048],
[ 0.1591, 2.4945],
[ 0.0893, 2.4522]],
[[ 0.0893, 2.4522],
[ 0.1591, 2.4945],
[ 1.0586, 1.0048],
[ 2.1433, -1.2428]]]])
>>> a[0, indices]
tensor([[[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688, 1.0082]],
[[ 0.5688, 1.0082],
[ 0.3093, -0.2445],
[-0.2050, -0.1651],
[-1.5964, -0.9236]]])
>>> a[1, indices]
tensor([[[ 2.1433, -1.2428],
[ 1.0586, 1.0048],
[ 0.1591, 2.4945],
[ 0.0893, 2.4522]],
[[ 0.0893, 2.4522],
[ 0.1591, 2.4945],
[ 1.0586, 1.0048],
[ 2.1433, -1.2428]]])
>>> a[0, indices[0]]
tensor([[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688, 1.0082]])
>>> a[1, indices[1]]
tensor([[ 0.0893, 2.4522],
[ 0.1591, 2.4945],
[ 1.0586, 1.0048],
[ 2.1433, -1.2428]])
>>> a[[[0], [1]], indices]
tensor([[[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688, 1.0082]],
[[ 0.0893, 2.4522],
[ 0.1591, 2.4945],
[ 1.0586, 1.0048],
[ 2.1433, -1.2428]]])
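As a quick sanity check, the general form can be compared against torch.gather, which is a different technique from the one in the answer and is shown here only as a cross-check. This is a minimal sketch on a fresh random tensor; by_indexing, by_gather and gather_idx are illustrative names, not part of the original code.

```python
import torch

a = torch.randn(3, 4, 2)
indices = a[:, :, 0].sort()[1]                      # shape [bn, k]

# Advanced-indexing version from the answer.
first_indices = torch.arange(a.shape[0])[:, None]   # shape [bn, 1]
by_indexing = a[first_indices, indices]

# Alternative (not from the answer): torch.gather along dim 1.
# gather needs an index with as many dims as a, so the [bn, k] indices
# are expanded across the last (coordinate) dimension.
gather_idx = indices[:, :, None].expand(-1, -1, a.shape[2])
by_gather = torch.gather(a, 1, gather_idx)

# Both approaches give the same per-batch ordering.
assert torch.equal(by_indexing, by_gather)

# Within each batch, the first coordinate is now non-decreasing.
assert (by_indexing[:, 1:, 0] >= by_indexing[:, :-1, 0]).all()
```

The indexing form is shorter; the gather form makes the sorted dimension explicit, which some readers may find easier to follow.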