Numba 函数与类型参数的函数使用无效



我正在使用Numba非python模式和一些NumPy函数。

@njit
def invert(W, copy=True):
'''
Inverts elementwise the weights in an input connection matrix.
In other words, change the from the matrix of internode strengths to the
matrix of internode distances.
If copy is not set, this function will *modify W in place.*
Parameters
----------
W : np.ndarray
weighted connectivity matrix
copy : bool
Returns
-------
W : np.ndarray
inverted connectivity matrix
'''
if copy:
W = W.copy()
E = np.where(W)
W[E] = 1. / W[E]
return W

在此函数中,W是一个矩阵。但是我得到了以下错误。它可能与W[E] = 1. / W[E]行有关。

File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, A), tuple(array(int64, 1d, C) x 2))

那么使用NumPy和Numba的正确方法是什么?我知道NumPy在矩阵计算方面做得很好。在这种情况下,NumPy是否足够快,以至于Numba不再提供加速?

正如 FBruzzesi 在注释中提到的,代码没有编译的原因是你使用了"花哨的索引",因为W[E]中的Enp.where的输出,是数组的元组。(这解释了稍微隐晦的错误消息:Numba 不知道如何使用getitem,即当其中一个输入是元组时,它不知道如何在括号中找到某些东西。

Numba 实际上支持在单个维度上进行花哨的索引(也称为"高级索引"(,只是不支持多个维度。在您的情况下,这允许进行简单的修改:首先使用ravel几乎无成本地使您的数组成为一维数组,然后应用转换,然后便宜地reshape回来。

@njit
def invert2(W, copy=True):
if copy:
W = W.copy()
Z = W.ravel()
E = np.where(Z)
Z[E] = 1. / Z[E]
return Z.reshape(W.shape)

但这仍然比需要的要慢,因为计算通过不必要的中间数组传递,而不是在遇到非零值时立即修改数组。简单地做一个循环会更快:

@njit 
def invert3(W, copy=True): 
if copy: 
W = W.copy() 
Z = W.ravel() 
for i in range(len(Z)): 
if Z[i] != 0: 
Z[i] = 1/Z[i] 
return Z.reshape(W.shape) 

无论W的尺寸如何,此代码都有效。如果我们知道W是二维的,那么我们可以直接迭代二维,但由于两者具有相似的性能,我选择更一般的路线。

在我的计算机上,假设一个 300 x 300 的数组W其中大约一半的条目是 0,并且invert是没有 Numba 编译的原始函数,计时是:

In [80]: %timeit invert(W)                                                                   
2.67 ms ± 49.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [81]: %timeit invert2(W)                                                                  
519 µs ± 24.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [82]: %timeit invert3(W)                                                                  
186 µs ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

所以 Numba 给了我们一个相当大的加速(在它已经运行过一次以消除编译时间之后(,尤其是在代码以 Numba 可以利用的高效循环风格重写之后。

最新更新