numpy数组的分区和排序


>> arr = [10, 11, 4, 3, 5, 7, 9, 2, 13]
>> np.partition(np.array(arr), -3)
array([ 9,  5,  4,  3,  2,  7, 10, 11, 13])
>> np.sort(np.partition(np.array(arr), -3)[-4:])
array([ 7, 10, 11, 13])
>> np.argpartition(np.array(arr), -3)
array([6, 4, 2, 3, 7, 5, 0, 1, 8], dtype=int64)
>> np.sort(np.argpartition(np.array(arr), -3)[-4:])
array([0, 1, 5, 8], dtype=int64)

这个代码中发生了什么?事实上,我已经看过文件,但无法从数字上理解这一点。

将一个简单的Python列表命名为arr是一种糟糕的做法。目前,这只是一个列表数组将在以后创建。

为了更好地理解正在发生的事情,最好划分代码将每个部分结果保存在单独的变量下。这就是我如何重写你的代码。

所以让我们从开始

lst = [10, 11, 4, 3, 5, 7, 9, 2, 13]

第二步是从此列表创建数组

arr1 = np.array(lst)

我决定将这个(以及下面的数组(命名为";arr";具有连续数字。

第三步是分区arr1;阈值";元素位于最后的第三个位置:

arr2 = np.partition(arr1, -3)

结果是:

array([ 9,  5,  4,  3,  2,  7, 10, 11, 13])

详细信息:

  • ;阈值";元素(10(位于从末端算起的第三个位置
  • 所有前面的元素都小于阈值
  • 以下所有元素都大于或等于阈值
  • 关于元素在"阈值";元素

然后你想得到arr2的最后4个元素:

arr3 = arr2[-4:]

毫不奇怪,结果是:

array([ 7, 10, 11, 13])

下一步是对它们进行排序:

arr4 = np.sort(arr3)

这一次没有任何变化,arr4的内容与arr3相同。

到目前为止,您已经完成了分区的实验,第二部分是argpartition的实验:

arr5 = np.argpartition(arr1, -3)

结果是:

array([6, 4, 2, 3, 7, 5, 0, 1, 8], dtype=int64)

它是arr1索引数组。

详细信息:

  • 从末尾开始的第三个元素(0(是";阈值";arr1中的元素(其值为10(
  • 所有先前的元素都是较小元素的索引(在arr1中(
  • 以下所有元素都是大于或等于元素的索引(在arr1中(

然后你得到arr5的最后4个元素:

arr6 = arr5[-4:]

获取:

array([5, 0, 1, 8], dtype=int64)

最后一步是对它们进行排序:

arr7 = np.sort(arr6)

得到(毫不奇怪(:

array([0, 1, 5, 8], dtype=int64)

仅此而已。

最新更新