对np.argpartition的工作理解有问题



我在执行np.argpartition时有问题这里是n。array

example = np.array([[5,6,7,3,4],[1,2,3,7,5],[6,7,4,2,3],[1,2,3,5,9],[2,3,6,1,2,]])
out: [[5 6 7 3 4]
[1 2 3 7 5]
[6 7 4 2 3]
[1 2 3 5 9]
[2 3 6 1 2]]

我可以通过np。argsort

获得排序数组的索引
print(np.argsort(example))
out:
[[3 4 0 1 2]
[0 1 2 4 3]
[3 4 2 0 1]
[0 1 2 3 4]
[3 0 4 1 2]]

我想用np。为了节省执行时间,因为我只需要在这个数组的每行中有3个排序的元素。我使用以下代码:

print(np.argpartition(example, 3, axis=1))
out: [[3 4 0 1 2]
[1 0 2 4 3]
[3 4 2 0 1]
[1 0 2 3 4]
[3 4 0 1 2]]

我期望每行的前三个索引将匹配排序数组中的索引,但事实并非如此,这不起作用。我不明白我做错了什么。

np.argpartition(example, k, axis=1)不返回前k个元素的排序数组。它只返回只有(k+1)个元素排序的索引。如果在输出中看到,只有第4个元素与argsort()

匹配如果你想要前三个元素排序,你必须给出一个k参数列表

index_array = np.argpartition(example, [0,1,2], axis=1)
print(np.take_along_axis(example,index_array, axis=1)) ##this will give you first 3 sorted elements

相关内容

  • 没有找到相关文章

最新更新