我正试图编译一个接受numpy数组和元组的函数形式为*arg的参数使用numba。
import numba as nb
import numpy as np
@nb.njit(cache=True)
def myfunc(t, *p):
val = 0
for j in range(0, len(p), 2):
val += p[j]*np.exp(-p[j+1]*t)
return val
T = np.arange(12)
pars = (1.0, 2.0, 3.0, 4.0)
mfunc = myfunc(T, *pars)
然而,我得到这个结果
In [1]: run numba_test.py
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
~/Programs/my-python/numba_test.py in <module>
12
13 T = np.arange(12)
---> 14 mfunc = myfunc(T, 1.0, 2.0, 3.0, 4.0)
...
...
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function iadd>) with argument(s) of type(s): (Literal[int](0), array(float64, 1d, C))
Known signatures:
* (int64, int64) -> int64
* (int64, uint64) -> int64
* (uint64, int64) -> int64
* (uint64, uint64) -> uint64
* (float32, float32) -> float32
* (float64, float64) -> float64
* (complex64, complex64) -> complex64
* (complex128, complex128) -> complex128
* parameterized
In definition 0:
All templates rejected with literals.
...
...
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at /home/cshugert/Programs/my-python/numba_test.py (9)
File "numba_test.py", line 9:
def myfunc(t, *p):
<source elided>
for j in range(0, len(p), 2):
val += p[j]*np.exp(-p[j+1]*t)
^
Numba确实支持使用元组,所以我认为我可以在jit编译器中添加一些签名。但是,我不确定确切地说该放什么。numba编译器会是这样吗无法处理将*args作为参数的函数?我能做些什么让我的功能正常工作吗?
让我们再次看到错误消息
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function iadd>) with argument(s)
of type(s): (Literal[int](0), array(float64, 1d, C))
Known signatures:
* (int64, int64) -> int64
* (int64, uint64) -> int64
* (uint64, int64) -> int64
* (uint64, uint64) -> uint64
* (float32, float32) -> float32
* (float64, float64) -> float64
* (complex64, complex64) -> complex64
* (complex128, complex128) -> complex128
* parameterized
错误是针对<built-in function iadd>
,即+
。如果您查看错误,它不是由于*args
的传递,而是由于以下语句:
val += p[j]*np.exp(-p[j+1]*t)
基本上,在提到的+
的所有兼容类型中,它不支持将integer
添加到array
(有关更多信息,请参阅错误消息和已知签名(。
您可以通过使用np.zeros
将val
设置为零数组来解决此问题(请参阅此处的文档(。
import numba as nb
import numpy as np
@nb.njit
def myfunc(t, *p):
val = np.zeros(12) #<------------ Set it as an array of zeros
for j in range(0, len(p), 2):
val += p[j]*np.exp(-p[j+1]*t)
return val
T = np.arange(12)
pars = (1.0, 2.0, 3.0, 4.0)
mfunc_val = myfunc(T, *pars)
你可以在这个谷歌Colab笔记本上查看这里的代码。