在numpy中,找到数组中非零序列最短的数组的计算效率最高的方法



假设我有一个数组

import numpy as np 
z = np.array(
    [
     [1, 1, 0, 0, 0, 0],
     [1, 1, 1, 1, 1, 0],
     [1, 1, 1, 0, 0, 0],
     [1, 1, 1, 1, 1, 1],
    ]
)

其中,1从每个数组的左侧开始,0从右侧开始(如果有的话(。对于许多应用程序,这就是填充数组的方式,以便在数组数组中每个数组的长度相同。

对于这样一个数组,我如何得到非零的最短序列。

在这种情况下,最短的序列是长度为2的第一个数组。

显而易见的答案是迭代每个数组并找到第一个零的索引,但我觉得可能有一种方法更能利用numpy的c处理。

具有5000×5000阵列的基准:

 74.3 ms  Dani
 33.8 ms  user19077881
  2.6 ms  Kelly1
  1.4 ms  Kelly2

我的Kelly1是从右上到左下的O(m+n(鞍背搜索:

def Kelly1(z):
    m, n = z.shape
    j = n - 1
    for i in range(m):
        while not z[i, j]:
            j -= 1
            if j < 0:
                return 0
    return j + 1

(Michael Szczesny说,如果我没记错的话,使用Numba可以把速度提高150倍。不过,我自己没有能力测试。(

我的Kelly2是一个O(m log n(水平二进制搜索,使用NumPy检查列是否充满非零:

def Kelly2(z):
    m, n = z.shape
    lo, hi = 0, n
    while lo < hi:
        mid = (lo + hi) // 2
        if z[:, mid].all():
            lo = mid + 1
        else:
            hi = mid
    return lo

(使用bisectkey可能会更短,但我现在没有Python 3.10要测试。(

注意:Dani和user19077881返回的结果不同:任意行中非零数最少的行,或者非零数最小的行。我听从了丹妮的指挥,因为这是公认的答案。这其实并不重要,因为你可以很快地从另一个结果中计算出一个结果(分别通过找到列或行中第一个零的索引(。

完整的基准代码(在线试用!(:

import numpy as np
from timeit import timeit
import random
m, n = 5000, 5000
def genz():
    lo = random.randrange(n*5//100, n//3)
    return np.array(
        [
            [1]*ones + [0]*(n-ones)
            for ones in random.choices(range(lo, n+1), k=m)
        ]
    )
def Dani(z):
    return np.count_nonzero(z, axis=1).min()
def user19077881(z):
    z_sums = z.sum(axis = 1)
    z_least = np.argmin(z_sums)
    return z_least
def Kelly1(z):
    m, n = z.shape
    j = n - 1
    for i in range(m):
        while not z[i, j]:
            j -= 1
            if j < 0:
                return 0
    return j + 1
def Kelly2(z):
    m, n = z.shape
    lo, hi = 0, n
    while lo < hi:
        mid = (lo + hi) // 2
        if z[:, mid].all():
            lo = mid + 1
        else:
            hi = mid
    return lo
funcs = Dani, user19077881, Kelly1, Kelly2
for _ in range(3):
    z = genz()
    for f in funcs:
        t = timeit(lambda: f(z), number=1)
        print('%5.1f ms ' % (t * 1e3), f.__name__)
    print()

使用np.count_nonzero+np.min:

res = np.count_nonzero(z, axis=1).min()
print(res)

输出

2

函数count_nonzero返回如下数组:

[2 5 3 6]

然后简单地找到最小值。

如果需要该行的索引,请改用np.argmin

如果您想知道哪个子数组的零最少,那么您可以使用:

z_sums = z.sum(axis = 1)
z_least = np.argmin(z_sums)

最新更新