我无法编译此代码:
import numpy as np
import numba
from numba import jit, float64, complex128
import math
@jit(complex128[:](float64,float64[:],float64))
def GaborWavelet(omega, t, Gabor_coef):
c1 = 0.3251520240633*math.sqrt(omega)
c2 = -0.5*Gabor_coef
c3 = omega*0.187390625129278
res = np.array(c2*(t * c3)**2, dtype = np.complex128)
res.imag = omega*t
return c1*np.exp(res)
它提出:
编译正在返回到启用环举的对象模式,因为函数";GaborWavelet";类型推理失败,原因是:未找到签名的函数function((的实现:
数组(数组(float64,1d,C(,dtype=class(complex128((
有两种候选实现:-其中2个不匹配,原因是:函数"array"中的重载:文件:numba\core\ting\npydecl.py:行504。带参数:"(数组(float64,1d,C(,dtype=class(complex128((":由于实施引发特定错误而被拒绝:键入错误:在同质序列中不允许使用数组(float64,1d,C(
res = np.array(c2*(t * c3)**2, dtype = np.complex128)
^
我做错了什么?
如何编译这个代码(里面有numpy方法(?
Numba不支持您使用的两种功能,但支持等效选项:
-
通过
np.array(arr, dtype=type)
进行类型转换。请改用arr.astype(type)
。 -
为复杂数据类型设置
arr.imag=values
。请改用arr += values*1j
。
以下代码在我的机器上工作,应该会产生等效的结果:
import numpy as np
import numba
from numba import jit, float64, complex128
import math
@jit(complex128[:](float64,float64[:],float64))
def GaborWavelet(omega, t, Gabor_coef):
c1 = 0.3251520240633*math.sqrt(omega)
c2 = -0.5*Gabor_coef
c3 = omega*0.187390625129278
res = (c2*(t * c3)**2).astype(np.complex128)
res += omega*t*1j
return c1*np.exp(res)