Is there a faster way to compute a running median?



Is there a built-in function or a faster way to compute the following?

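import numpy as np

x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])
def half_life_idx(x):
    # Start from half the total and subtract elements until it drops to zero or below
    middle = sum(x) / 2
    for idx, val in enumerate(x):
        middle = middle - val
        if middle <= 0:
            break
    return idx
half_life_idx(x)
>> 1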

In other words, I want to find the first index of x at which the cumulative sum of x[0:index+1] is >= sum(x)/2.
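To make the condition concrete, here are the running totals for the example array. This is a small illustrative snippet, not part of the original question; np.nonzero is just one way to pick out the indices that satisfy the condition.

```python
import numpy as np

x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])

cs = np.cumsum(x)   # running totals: 67, 118, 160, ..., 235
half = cs[-1] / 2   # 235 / 2 = 117.5

# Indices whose running total has reached at least half the grand total
reached = np.nonzero(cs >= half)[0]
first_idx = int(reached[0])
print(cs.tolist(), half, first_idx)
# [67, 118, 160, 197, 218, 228, 230, 232, 233, 234, 235] 117.5 1
```

The first running total to reach 117.5 is 118 at index 1, which is the 1 the loop returns.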

You can write a faster version by combining cumsum with the searchsorted method:

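def half_life_idx_ww(x):
    cs = np.cumsum(x)       # running totals
    middle = cs[-1]/2       # half of the grand total
    # first index i with cs[i] >= middle (searchsorted defaults to side='left')
    return cs.searchsorted(middle)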
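A short check of why searchsorted agrees with the loop: ndarray.searchsorted defaults to side='left', so it returns the first index i with cs[i] >= middle, which is exactly where the loop's middle <= 0 test first succeeds. The array ties below is made up so that a running total lands exactly on the half-way point; both functions are re-declared so the snippet runs on its own.

```python
import numpy as np

def half_life_idx(x):
    # Original loop: stop once the running total reaches half the sum
    middle = sum(x) / 2
    for idx, val in enumerate(x):
        middle = middle - val
        if middle <= 0:
            break
    return idx

def half_life_idx_ww(x):
    cs = np.cumsum(x)
    middle = cs[-1] / 2
    # side='left' (the default): first i with cs[i] >= middle
    return cs.searchsorted(middle)

ties = np.array([1, 1, 1, 1])  # cumsum [1, 2, 3, 4], half-way point 2.0
print(half_life_idx(ties))                              # 1
print(half_life_idx_ww(ties))                           # 1
print(np.cumsum(ties).searchsorted(2.0, side='right'))  # 2, one past the tie
```

With side='right' the result would move one past an exact tie, so the default is the one that reproduces the original function.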

For example:

In [167]: x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])
In [168]: half_life_idx(x), half_life_idx_ww(x)
Out[168]: (1, 1)
In [169]: w = np.random.gamma(1.5, size=200)
In [170]: half_life_idx(w), half_life_idx_ww(w)
Out[170]: (99, 99)

Another approach is np.argmax; see function f1 in this example:

import numpy as np
def f0(x):
    # leermeester's original method
    middle = sum(x) / 2
    for idx, val in enumerate(x):
        middle = middle - val
        if middle <= 0:
            break
    return idx
def f1(x):
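    # my method using argmax: index of the first True in the boolean mask
    cs = x.cumsum()
    # >= (not >) matches the loop's `middle <= 0` stopping condition on exact ties
    return np.argmax(cs >= cs[-1]/2)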
def f2(x):
    #Warren Weckesser's method using searchsorted
    cs = np.cumsum(x)
    middle = cs[-1]/2
    return cs.searchsorted(middle)
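f1 relies on np.argmax of a boolean array returning the index of the first True, since ties for the maximum resolve to the first occurrence. One caveat: when no element is True, argmax returns 0, which looks the same as "the first element qualified". For non-negative x with a >= comparison this case cannot occur, because the final running total is always at least half of itself. The helper first_true below is hypothetical, a sketch of one way to make the not-found case explicit.

```python
import numpy as np

mask = np.array([False, False, True, True])
print(np.argmax(mask))        # 2: index of the first True

none_true = np.zeros(3, dtype=bool)
print(np.argmax(none_true))   # 0: indistinguishable from "element 0 is True"

def first_true(mask):
    # Hypothetical helper: return -1 instead of 0 when nothing is True
    i = int(np.argmax(mask))
    return i if mask[i] else -1

print(first_true(mask), first_true(none_true))  # 2 -1
```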

Here are some benchmarks for each method:

print("small run")
x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])
%timeit(f0(x))
%timeit(f1(x))
%timeit(f2(x))

print("larger run")
x = np.random.rand(int(1.0E3))
%timeit(f0(x))
%timeit(f1(x))
%timeit(f2(x))
print("very large run")
x = np.random.rand(int(1.0E6))
%timeit(f0(x))
%timeit(f1(x))
%timeit(f2(x))
# print all three results to check that they agree
print(f0(x),f1(x),f2(x))
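%timeit is IPython magic, so the benchmark above only runs in IPython or Jupyter. Outside IPython, a similar comparison can be made with the standard-library timeit module. This is a minimal sketch, assuming f0, f1 and f2 are re-declared here so the block runs on its own (f1 written with >=, matching the loop's stopping condition); best_per_call is a hypothetical helper name, and the loop and repeat counts are arbitrary choices.

```python
import timeit

import numpy as np

def f0(x):
    # Original loop
    middle = sum(x) / 2
    for idx, val in enumerate(x):
        middle = middle - val
        if middle <= 0:
            break
    return idx

def f1(x):
    # argmax of the boolean mask: first index reaching half the total
    cs = x.cumsum()
    return np.argmax(cs >= cs[-1] / 2)

def f2(x):
    # searchsorted (side='left'): first index with cs[i] >= half the total
    cs = np.cumsum(x)
    return cs.searchsorted(cs[-1] / 2)

def best_per_call(func, x, number=100, repeat=5):
    # timeit.repeat returns the total seconds for `number` calls, once per repeat;
    # the minimum is the least noisy estimate, divided down to a single call
    times = timeit.repeat(lambda: func(x), number=number, repeat=repeat)
    return min(times) / number

rng = np.random.default_rng(0)
for n in (11, 1_000):
    x = rng.random(n)
    timings = {f.__name__: best_per_call(f, x, number=200, repeat=3) for f in (f0, f1, f2)}
    print(n, {name: f"{t * 1e6:.2f} µs" for name, t in timings.items()})
    print("results:", f0(x), f1(x), f2(x))
```

For the very large case, something like best_per_call(f0, rng.random(10**6), number=5) keeps the slow pure-Python loop from dominating the run time; absolute numbers will depend on the machine and NumPy version.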

Benchmark results:

small run
2.48 µs ± 41.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
3.47 µs ± 57.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
2.7 µs ± 49.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
larger run
184 µs ± 2.59 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
6.2 µs ± 51.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
5.01 µs ± 14.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
very large run
185 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.3 ms ± 230 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.64 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
500260 500260 500260

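Conclusion: for very small arrays your method is the fastest (2.48 µs versus 2.7 µs for searchsorted), but for larger arrays it is much slower than the suggested answers: roughly 35 times slower than searchsorted at 1,000 elements and about 70 times slower at 1,000,000. Warren's solution was faster than mine in every run; mine took roughly 25 to 30% longer.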
