使用 numpy 进行 numba 编译有什么问题?



我无法编译此代码:

import numpy as np
import numba
from numba import jit, float64, complex128
import math
@jit(complex128[:](float64,float64[:],float64))
def GaborWavelet(omega, t, Gabor_coef):

c1 = 0.3251520240633*math.sqrt(omega)
c2 = -0.5*Gabor_coef
c3 = omega*0.187390625129278

res = np.array(c2*(t * c3)**2, dtype = np.complex128)

res.imag = omega*t

return c1*np.exp(res)

它提出:

编译正在返回到启用环举的对象模式,因为函数";GaborWavelet";类型推理失败,原因是:未找到签名的函数function((的实现:

数组(数组(float64,1d,C(,dtype=class(complex128((

有两种候选实现:-其中2个不匹配,原因是:函数"array"中的重载:文件:numba\core\ting\npydecl.py:行504。带参数:"(数组(float64,1d,C(,dtype=class(complex128((":由于实施引发特定错误而被拒绝:键入错误:在同质序列中不允许使用数组(float64,1d,C(

res = np.array(c2*(t * c3)**2, dtype = np.complex128)
^

我做错了什么?

如何编译这个代码(里面有numpy方法(?

Numba不支持您使用的两种功能,但支持等效选项:

  1. 通过np.array(arr, dtype=type)进行类型转换。请改用arr.astype(type)

  2. 为复杂数据类型设置arr.imag=values。请改用arr += values*1j

以下代码在我的机器上工作,应该会产生等效的结果:

import numpy as np
import numba
from numba import jit, float64, complex128
import math
@jit(complex128[:](float64,float64[:],float64))
def GaborWavelet(omega, t, Gabor_coef):
c1 = 0.3251520240633*math.sqrt(omega)
c2 = -0.5*Gabor_coef
c3 = omega*0.187390625129278
res = (c2*(t * c3)**2).astype(np.complex128)
res += omega*t*1j
return c1*np.exp(res)

最新更新