>> arr = [10, 11, 4, 3, 5, 7, 9, 2, 13]
>> np.partition(np.array(arr), -3)
array([ 9, 5, 4, 3, 2, 7, 10, 11, 13])
>> np.sort(np.partition(np.array(arr), -3)[-4:])
array([ 7, 10, 11, 13])
>> np.argpartition(np.array(arr), -3)
array([6, 4, 2, 3, 7, 5, 0, 1, 8], dtype=int64)
>> np.sort(np.argpartition(np.array(arr), -3)[-4:])
array([0, 1, 5, 8], dtype=int64)
这个代码中发生了什么?事实上,我已经看过文件,但无法从数字上理解这一点。
将一个简单的Python列表命名为arr是一种糟糕的做法。目前,这只是一个列表,数组将在以后创建。
为了更好地理解正在发生的事情,最好划分代码将每个部分结果保存在单独的变量下。这就是我如何重写你的代码。
所以让我们从开始
lst = [10, 11, 4, 3, 5, 7, 9, 2, 13]
第二步是从此列表创建数组:
arr1 = np.array(lst)
我决定将这个(以及下面的数组(命名为";arr";具有连续数字。
第三步是分区arr1;阈值";元素位于最后的第三个位置:
arr2 = np.partition(arr1, -3)
结果是:
array([ 9, 5, 4, 3, 2, 7, 10, 11, 13])
详细信息:
- ;阈值";元素(10(位于从末端算起的第三个位置
- 所有前面的元素都小于阈值
- 以下所有元素都大于或等于阈值
- 关于元素在"阈值";元素
然后你想得到arr2的最后4个元素:
arr3 = arr2[-4:]
毫不奇怪,结果是:
array([ 7, 10, 11, 13])
下一步是对它们进行排序:
arr4 = np.sort(arr3)
这一次没有任何变化,arr4的内容与arr3相同。
到目前为止,您已经完成了分区的实验,第二部分是argpartition的实验:
arr5 = np.argpartition(arr1, -3)
结果是:
array([6, 4, 2, 3, 7, 5, 0, 1, 8], dtype=int64)
它是到arr1的索引数组。
详细信息:
- 从末尾开始的第三个元素(0(是";阈值";arr1中的元素(其值为10(
- 所有先前的元素都是较小元素的索引(在arr1中(
- 以下所有元素都是大于或等于元素的索引(在arr1中(
然后你得到arr5的最后4个元素:
arr6 = arr5[-4:]
获取:
array([5, 0, 1, 8], dtype=int64)
最后一步是对它们进行排序:
arr7 = np.sort(arr6)
得到(毫不奇怪(:
array([0, 1, 5, 8], dtype=int64)
仅此而已。