Python与cython的快速余弦距离



我想尽可能快地余弦距离计算scipy.spatial.distance.cosine,所以我尝试使用numpy

def alt_cosine(x,y):
return 1 - np.inner(x,y)/np.sqrt(np.dot(x,x)*np.dot(y,y))

我试过胞苷

from libc.math cimport sqrt
def alt_cosine_2(x,y):
return 1 - np.inner(x,y)/sqrt(np.dot(x,x)*np.dot(y,y))

并逐渐得到改进(在长度为50的numpy阵列上测试(

>>> cosine() # ... make some timings
5.27526156300155e-05 # mean calculation time for one loop
>>> alt_cosine() 
9.913400815003115e-06
>>> alt_cosine_2()
7.0269494536660205e-06

最快的方法是什么不幸的是,我无法为alt_cosine_2指定变量类型,我将对类型为np.float32的numpy数组使用此函数

有一种观点认为,借助cython或numba无法加快numpy的功能。但这并不完全正确:numpy的目标是为各种场景提供出色的性能,但这也意味着在特殊场景中的性能有些不完美。

有了特定的场景,您就有机会改进numpy的性能,即使这意味着要重写numpy的一些功能。例如,在这种情况下,我们可以使用cython将函数加速因子4,使用numba将因子8。

让我们从您的版本作为基线开始(请参阅答案末尾的列表(:

>>>%timeit cosine(x,y)   # scipy's
31.9 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
>>>%timeit np_cosine(x,y)  # your numpy-version
4.05 µs ± 19.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np_cosine_fhtmitchell(x,y)  # @FHTmitchell's version
4 µs ± 53.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>>%timeit np_cy_cosine(x,y)
2.56 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

所以我看不到@FHTmitchell版本的改进,但在其他方面与您的时间安排没有什么不同。

您的向量只有50个元素,因此实际计算大约需要200-300ns:其他一切都是调用函数的开销。减少开销的一种可能性是在cython:的帮助下,每只手"内联"这些功能

%%cython 
from libc.math cimport sqrt
import numpy as np
cimport numpy as np
def cy_cosine(np.ndarray[np.float64_t] x, np.ndarray[np.float64_t] y):
cdef double xx=0.0
cdef double yy=0.0
cdef double xy=0.0
cdef Py_ssize_t i
for i in range(len(x)):
xx+=x[i]*x[i]
yy+=y[i]*y[i]
xy+=x[i]*y[i]
return 1.0-xy/sqrt(xx*yy)

这导致:

>>> %timeit cy_cosine(x,y)
921 ns ± 19.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

不错!我们可以尝试通过以下更改来放弃一些安全性(运行时检查+ieee-754标准(,从而挤出更多的性能:

%%cython  -c=-ffast-math
...
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
def cy_cosine_perf(np.ndarray[np.float64_t] x, np.ndarray[np.float64_t] y):
...

这导致:

>>> %timeit cy_cosine_perf(x,y)
828 ns ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

即另外10%,这意味着几乎比numpy版本快5倍。

还有另一种工具可以提供类似的功能/性能-numba:

import numba as nb
import numpy as np
@nb.jit(nopython=True, fastmath=True)
def nb_cosine(x, y):
xx,yy,xy=0.0,0.0,0.0
for i in range(len(x)):
xx+=x[i]*x[i]
yy+=y[i]*y[i]
xy+=x[i]*y[i]
return 1.0-xy/np.sqrt(xx*yy)

这导致:

>>> %timeit nb_cosine(x,y)
495 ns ± 5.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

与最初的numpy版本相比,速度提高了8。

numba可以更快有一些原因:Cython在运行时处理数据的步幅,这阻止了一些优化(例如矢量化(。Numba似乎处理得更好。

但这里的差异完全是由于numba的开销减少:

%%cython  -c=-ffast-math
import numpy as np
cimport numpy as np
def cy_empty(np.ndarray[np.float64_t] x, np.ndarray[np.float64_t] y):
return x[0]*y[0]
import numba as nb
import numpy as np
@nb.jit(nopython=True, fastmath=True)
def nb_empty(x, y):
return x[0]*y[0]
%timeit cy_empty(x,y)
753 ns ± 6.81 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit nb_empty(x,y)
456 ns ± 2.47 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

numba的开销几乎减少了2倍!

正如@max9111所指出的,numpy内联了其他jit函数,但它也能够用很少的开销调用一些numpy函数,因此以下版本(用dot替换inner(:

@nb.jit(nopython=True, fastmath=True)
def np_nb_cosine(x,y):
return 1 - np.dot(x,y)/sqrt(np.dot(x,x)*np.dot(y,y))
>>> %timeit np_nb_cosine(x,y)
605 ns ± 5.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 

仅慢约10%。


请注意,上述比较仅对具有50个元素的向量有效。对于更多的元素,情况则完全不同:numpy版本使用了点积的并行mkl(或类似(实现,将轻松击败我们的简单尝试。

这就引出了一个问题:针对特定大小的输入优化代码真的值得吗?有时答案是肯定的,有时答案是否定的。

如果可能的话,我会得到numba+dot解决方案,它对小输入非常快,但对大输入也具有mkl实现的全部能力。


还有一个细微的区别:第一个版本返回np.float64-对象,cython和numba版本返回Python浮点。


列表:

from scipy.spatial.distance import cosine
import numpy as np
x=np.arange(50, dtype=np.float64)
y=np.arange(50,100, dtype=np.float64)
def np_cosine(x,y):
return 1 - inner(x,y)/sqrt(np.dot(x,x)*dot(y,y))
from numpy import inner, sqrt, dot
def np_cosine_fhtmitchell(x,y):
return 1 - inner(x,y)/sqrt(np.dot(x,x)*dot(y,y))
%%cython
from libc.math cimport sqrt
import numpy as np
def np_cy_cosine(x,y):
return 1 - np.inner(x,y)/sqrt(np.dot(x,x)*np.dot(y,y))

加速这种代码的懒惰方法:

  1. 使用numexprPython模块
  2. 使用numbaPython模块
  3. 使用NumPy函数的SciPy等价物

不幸的是,这些技巧都不适用于您,因为:

  1. dotinner未在numexpr中实现
  2. numba(类似Cython(不会加快对NumPy函数的调用
  3. CCD_ 15和CCD_ 16在CCD_

也许你最好的选择是尝试在不同的底层LA库(例如LAPACK、BLAS、OpenBLAS等(和编译选项(例如多线程等(下编译numpy,看看哪种组合对你的用例最有效。

祝你好运!

最新更新