numpy.argmax/argmin的更快替代方案,后者速度较慢



我在Python中使用了很多argminargmax

不幸的是,功能非常缓慢。

我在周围做了一些搜索,我能找到的最好的是这里:

http://lemire.me/blog/archives/2008/12/17/fast-argmax-in-python/

def fastest_argmax(array):
    array = list( array )
    return array.index(max(array))

不幸的是,这个解决方案的速度仍然只有np.max的一半,我认为我应该能够找到与np.max一样快的东西。

x = np.random.randn(10)
%timeit np.argmax( x )
10000 loops, best of 3: 21.8 us per loop
%timeit fastest_argmax( x )    
10000 loops, best of 3: 20.8 us per loop

注意,我将此应用于的Pandas DataFrame Group

例如

%timeit grp2[ 'ODDS' ].agg( [ fastest_argmax ] )
100 loops, best of 3: 8.8 ms per loop
%timeit grp2[ 'ODDS' ].agg( [ np.argmax ] )
100 loops, best of 3: 11.6 ms per loop

grp2[ 'ODDS' ].head()如下所示:

EVENT_ID   SELECTION_ID        
104601100  4367029       682508    3.05
                         682509    3.15
                         682510    3.25
                         682511    3.35
           5319660       682512    2.04
                         682513    2.08
                         682514    2.10
                         682515    2.12
                         682516    2.14
           5510310       682520    4.10
                         682521    4.40
                         682522    4.50
                         682523    4.80
                         682524    5.30
           5559264       682526    5.00
                         682527    5.30
                         682528    5.40
                         682529    5.50
                         682530    5.60
           5585869       682533    1.96
                         682534    1.97
                         682535    1.98
                         682536    2.02
                         682537    2.04
           6064546       682540    3.00
                         682541    2.74
                         682542    2.76
                         682543    2.96
                         682544    3.05
104601200  4916112       682548    2.64
                         682549    2.68
                         682550    2.70
                         682551    2.72
                         682552    2.74
           5315859       682557    2.90
                         682558    2.92
                         682559    3.05
                         682560    3.10
                         682561    3.15
           5356995       682564    2.42
                         682565    2.44
                         682566    2.48
                         682567    2.50
                         682568    2.52
           5465225       682573    1.85
                         682574    1.89
                         682575    1.91
                         682576    1.93
                         682577    1.94
           5773661       682588    5.00
                         682589    4.40
                         682590    4.90
                         682591    5.10
           6013187       682592    5.00
                         682593    4.20
                         682594    4.30
                         682595    4.40
                         682596    4.60
104606300  2489827       683438    4.00
                         683439    3.90
                         683440    3.95
                         683441    4.30
                         683442    4.40
           3602724       683446    2.16
                         683447    2.32
Name: ODDS, Length: 65, dtype: float64

事实证明,np.argmax非常快,但只有使用本机numpy数组。对于国外数据,几乎所有的时间都花在了转换上:

In [194]: print platform.architecture()
('64bit', 'WindowsPE')
In [5]: x = np.random.rand(10000)
In [57]: l=list(x)
In [123]: timeit numpy.argmax(x)
100000 loops, best of 3: 6.55 us per loop
In [122]: timeit numpy.argmax(l)
1000 loops, best of 3: 729 us per loop
In [134]: timeit numpy.array(l)
1000 loops, best of 3: 716 us per loop

我称你的函数为";低效的";因为它首先将所有内容转换为list,然后迭代2次(实际上是3次迭代+list构造)。

我本来打算提出这样一个只迭代一次的东西:

def imax(seq):
    it=iter(seq)
    im=0
    try: m=it.next()
    except StopIteration: raise ValueError("the sequence is empty")
    for i,e in enumerate(it,start=1):
        if e>m:
            m=e
            im=i
    return im

但是,您的版本会更快,因为它迭代了很多次,但都是用C代码,而不是Python代码。即使考虑到在转换上也花费了大量时间,C的速度也快得多:

In [158]: timeit imax(x)
1000 loops, best of 3: 883 us per loop
In [159]: timeit fastest_argmax(x)
1000 loops, best of 3: 575 us per loop
In [174]: timeit list(x)
1000 loops, best of 3: 316 us per loop
In [175]: timeit max(l)
1000 loops, best of 3: 256 us per loop
In [181]: timeit l.index(0.99991619010758348)  #the greatest number in my case, at index 92
100000 loops, best of 3: 2.69 us per loop

因此,进一步加快这一速度的关键知识是了解序列中的数据本身是哪种格式(例如,您是否可以省略转换步骤或使用/编写该格式的另一种本机功能)。

顺便说一句,使用aggregate(max_fn)而不是agg([max_fn])可能会获得一些加速。

对于那些返回第一个最小值的索引的短numpy免费片段:

def argmin(a):
    return min(range(len(a)), key=lambda x: a[x])
a = [6, 5, 4, 1, 1, 3, 2]
argmin(a)  # returns 3

你能发布一些代码吗?这是我电脑上的结果:

x = np.random.rand(10000)
%timeit np.max(x)
%timeit np.argmax(x)

输出:

100000 loops, best of 3: 7.43 µs per loop
100000 loops, best of 3: 11.5 µs per loop

最新更新