Numba: multiply arrays rowwise



I have NumPy arrays of shape (2,5) and (2,) and want to multiply them row by row:

a = np.array([[3,5,6,9,10],[4,7,8,11,12]])
b = np.array([-1,2])

From the question "numpy: multiply arrays rowwise", I know this can be done in NumPy with:

a * b[:,None] 

which gives the correct output:

array([[ -3,  -5,  -6,  -9, -10],
       [  8,  14,  16,  22,  24]])
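Why this works: b[:, None] inserts a new axis, turning the shape-(2,) vector into a (2, 1) column, and NumPy then broadcasts that column across the 5 columns of a, so row i of a is multiplied by b[i]. A small plain-NumPy sketch (the names col and c are only for illustration):

```python
import numpy as np

a = np.array([[3, 5, 6, 9, 10], [4, 7, 8, 11, 12]])
b = np.array([-1, 2])

col = b[:, None]        # new axis: shape (2,) -> (2, 1)
print(col.shape)        # (2, 1)

# (2, 5) * (2, 1): the size-1 axis is stretched across the 5 columns,
# so every element of row i is multiplied by b[i].
c = a * col

# b.reshape(-1, 1) builds the same (2, 1) column without None-indexing.
assert np.array_equal(c, a * b.reshape(-1, 1))
print(c)
```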

But in Numba this no longer works, and I get a pile of error messages.

Code:

import numpy as np
from numba import njit

@njit()
def fct(a,b):
    c = a * b[:,None]
    return c

a = np.array([[3,5,6,9,10],[4,7,8,11,12]])
b = np.array([-1,2])
A = fct(a, b)
print(A)

I put this code in a file named numba_questionA.py. Running it gives the following error:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:

>>> getitem(array(int32, 1d, C), Tuple(slice<a:b>, none))

There are 22 candidate implementations:
- Of which 20 did not match due to:
Overload of function 'getitem': File: <numerous>: Line N/A.
With argument(s): '(array(int32, 1d, C), Tuple(slice<a:b>, none))':
No match.
- Of which 2 did not match due to:
Overload in function 'GetItemBuffer.generic': File: numba\core\typing\arraydecl.py: Line 162.
With argument(s): '(array(int32, 1d, C), Tuple(slice<a:b>, none))':
Rejected as the implementation raised a specific error:
TypeError: unsupported array index type none in Tuple(slice<a:b>, none)
raised from numba_questionA.py
During: typing of intrinsic-call at numba_questionA.py
During: typing of static-get-item at numba_questionA.py
File "numba_questionA.py", line 6:
def fct(a,b):
    c = a * b[:,None]
    ^

Numba says it can't use None as an array index, so you can replace

b[:, None]

with

b.reshape(-1, 1)

However, for an expression like a * b[:,None], Numba won't be faster than NumPy, since NumPy already evaluates it in a single compiled loop.

If the arrays are really large, though, you can take advantage of Numba's parallelization:
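import numpy as np
import numba as nb

@nb.njit(parallel=True)
def fct(a, b):
    c = np.empty_like(a)
    # Each iteration handles one column; prange splits the columns across threads
    for i in nb.prange(a.shape[1]):
        c[:, i] = a[:, i] * b
    return c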

Using guvectorize is also a nice way to broadcast the data semi-automatically. The advantage is that it can be compiled for different targets (cpu, parallel, cuda).

For small arrays like the ones in your example, parallelization will probably only add overhead.

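import numba as nb

# int32 matches the question's arrays (default int on Windows with NumPy < 2);
# int64 is the default integer type on other platforms.
@nb.guvectorize(["void(int32[:], int32[:], int32[:])",
                 "void(int64[:], int64[:], int64[:])"],
                "(n),()->(n)", nopython=True, target="cpu")
def fct(a, b, out):
    out[:] = a * b
A = fct(a, b)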

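Recent Numba versions can also infer the data types automatically, provided you pass in the output array yourself and compile only for the cpu target: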

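import numpy as np
import numba as nb

# No type signatures: the types are inferred when the function is first called
@nb.guvectorize("(n),()->(n)", nopython=True, target="cpu")
def fct(a, b, out):
    out[:] = a * b

A = np.empty_like(a)
fct(a, b, A)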
