检查 numpy 数组是否排序



我有一个numpy数组,我喜欢检查它是否排序。

>>> a = np.array([1,2,3,4,5])
array([1, 2, 3, 4, 5])
np.all(a[:-1] <= a[1:])

例子:

is_sorted = lambda a: np.all(a[:-1] <= a[1:])
>>> a = np.array([1,2,3,4,9])
>>> is_sorted(a)
True
>>> a = np.array([1,2,3,4,3])
>>> is_sorted(a)
False

使用 NumPy 工具:

np.all(np.diff(a) >= 0)

但 numpy 解决方案都是 O(n)。

如果你想要快速代码和关于未排序数组的非常快速的结论:

import numba
@numba.jit
def is_sorted(a):
    for i in range(a.size-1):
         if a[i+1] < a[i] :
               return False
    return True
          

这是随机数组上的 O(1)(平均值)。

效率低下但易于键入的解决方案:

(a == np.sort(a)).all()

为了完整起见,下面可以找到 O(log n) 迭代解决方案。递归版本速度较慢,并且以大矢量大小崩溃。但是,它仍然比使用np.all(a[:-1] <= a[1:])的本机numpy慢,这很可能是由于现代CPU优化。O(log n) 更快的唯一情况是在"平均"随机情况下,或者如果它是"几乎"排序的。如果你怀疑你的数组已经完全排序,那么np.all会更快。

def is_sorted(a):
    idx = [(0, a.size - 1)]
    while idx:
        i, j = idx.pop(0) # Breadth-First will find almost-sorted in O(log N)
        if i >= j:
            continue
        elif a[i] > a[j]:
            return False
        elif i + 1 == j:
            continue
        else:
            mid = (i + j) >> 1 # Division by 2 with floor
            idx.append((i, mid))
            idx.append((mid, j))
    return True
is_sorted2 = lambda a: np.all(a[:-1] <= a[1:])

以下是结果:

# Already sorted array - np.all is super fast
sorted_array = np.sort(np.random.rand(1000000))
%timeit is_sorted(sorted_array)
659 ms ± 3.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit is_sorted2(sorted_array)
431 µs ± 35.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# Here I included the random in each command so we need to substract it's timing
%timeit np.random.rand(1000000)
6.08 ms ± 17.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit is_sorted(np.random.rand(1000000))
6.11 ms ± 58.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Without random part, it took 6.11 ms - 6.08 ms = 30µs per loop
%timeit is_sorted2(np.random.rand(1000000))
6.83 ms ± 75.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Without random part, it took 6.83 ms - 6.08 ms = 750µs per loop

Net,O(n)向量优化代码比O(log n)算法更好,除非你将运行>1亿个元素数组。

最新更新