我在Python中使用了很多argmin
和argmax
。
不幸的是,功能非常缓慢。
我在周围做了一些搜索,我能找到的最好的是这里:
http://lemire.me/blog/archives/2008/12/17/fast-argmax-in-python/
def fastest_argmax(array):
array = list( array )
return array.index(max(array))
不幸的是,这个解决方案的速度仍然只有np.max
的一半,我认为我应该能够找到与np.max
一样快的东西。
x = np.random.randn(10)
%timeit np.argmax( x )
10000 loops, best of 3: 21.8 us per loop
%timeit fastest_argmax( x )
10000 loops, best of 3: 20.8 us per loop
注意,我将此应用于的Pandas DataFrame Group
例如
%timeit grp2[ 'ODDS' ].agg( [ fastest_argmax ] )
100 loops, best of 3: 8.8 ms per loop
%timeit grp2[ 'ODDS' ].agg( [ np.argmax ] )
100 loops, best of 3: 11.6 ms per loop
grp2[ 'ODDS' ].head()
如下所示:
EVENT_ID SELECTION_ID
104601100 4367029 682508 3.05
682509 3.15
682510 3.25
682511 3.35
5319660 682512 2.04
682513 2.08
682514 2.10
682515 2.12
682516 2.14
5510310 682520 4.10
682521 4.40
682522 4.50
682523 4.80
682524 5.30
5559264 682526 5.00
682527 5.30
682528 5.40
682529 5.50
682530 5.60
5585869 682533 1.96
682534 1.97
682535 1.98
682536 2.02
682537 2.04
6064546 682540 3.00
682541 2.74
682542 2.76
682543 2.96
682544 3.05
104601200 4916112 682548 2.64
682549 2.68
682550 2.70
682551 2.72
682552 2.74
5315859 682557 2.90
682558 2.92
682559 3.05
682560 3.10
682561 3.15
5356995 682564 2.42
682565 2.44
682566 2.48
682567 2.50
682568 2.52
5465225 682573 1.85
682574 1.89
682575 1.91
682576 1.93
682577 1.94
5773661 682588 5.00
682589 4.40
682590 4.90
682591 5.10
6013187 682592 5.00
682593 4.20
682594 4.30
682595 4.40
682596 4.60
104606300 2489827 683438 4.00
683439 3.90
683440 3.95
683441 4.30
683442 4.40
3602724 683446 2.16
683447 2.32
Name: ODDS, Length: 65, dtype: float64
事实证明,np.argmax
非常快,但只有使用本机numpy数组。对于国外数据,几乎所有的时间都花在了转换上:
In [194]: print platform.architecture()
('64bit', 'WindowsPE')
In [5]: x = np.random.rand(10000)
In [57]: l=list(x)
In [123]: timeit numpy.argmax(x)
100000 loops, best of 3: 6.55 us per loop
In [122]: timeit numpy.argmax(l)
1000 loops, best of 3: 729 us per loop
In [134]: timeit numpy.array(l)
1000 loops, best of 3: 716 us per loop
我称你的函数为";低效的";因为它首先将所有内容转换为list,然后迭代2次(实际上是3次迭代+list构造)。
我本来打算提出这样一个只迭代一次的东西:
def imax(seq):
it=iter(seq)
im=0
try: m=it.next()
except StopIteration: raise ValueError("the sequence is empty")
for i,e in enumerate(it,start=1):
if e>m:
m=e
im=i
return im
但是,您的版本会更快,因为它迭代了很多次,但都是用C代码,而不是Python代码。即使考虑到在转换上也花费了大量时间,C的速度也快得多:
In [158]: timeit imax(x)
1000 loops, best of 3: 883 us per loop
In [159]: timeit fastest_argmax(x)
1000 loops, best of 3: 575 us per loop
In [174]: timeit list(x)
1000 loops, best of 3: 316 us per loop
In [175]: timeit max(l)
1000 loops, best of 3: 256 us per loop
In [181]: timeit l.index(0.99991619010758348) #the greatest number in my case, at index 92
100000 loops, best of 3: 2.69 us per loop
因此,进一步加快这一速度的关键知识是了解序列中的数据本身是哪种格式(例如,您是否可以省略转换步骤或使用/编写该格式的另一种本机功能)。
顺便说一句,使用aggregate(max_fn)
而不是agg([max_fn])
可能会获得一些加速。
对于那些返回第一个最小值的索引的短numpy免费片段:
def argmin(a):
return min(range(len(a)), key=lambda x: a[x])
a = [6, 5, 4, 1, 1, 3, 2]
argmin(a) # returns 3
你能发布一些代码吗?这是我电脑上的结果:
x = np.random.rand(10000)
%timeit np.max(x)
%timeit np.argmax(x)
输出:
100000 loops, best of 3: 7.43 µs per loop
100000 loops, best of 3: 11.5 µs per loop