Following the explanation given here [1], I tried to apply the same idea to speed up the following integral:
import numpy as np
import scipy.integrate as si
from scipy.optimize import root, fsolve
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def integrand(t, *args):
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0])
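To see where the time goes, here is a small instrumented variant of the integrand above. The counter `n_solves` is a hypothetical addition for illustration, not part of the original code. It suggests that `quad` calls the integrand once per quadrature node, so `fsolve` runs a full root solve on every evaluation even though `c` depends only on `a`:

```python
import numpy as np
import scipy.integrate as si
from scipy.optimize import fsolve

n_solves = 0  # hypothetical counter: how many times the root-finder runs


def integrand(t, a):
    global n_solves
    n_solves += 1
    # c depends only on a, yet it is recomputed for every t
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(-(t / (a * c))**2)


# full_output=1 makes quad also return a dict with the evaluation count
out = si.quad(integrand, 0, 1, args=(2.0,), full_output=1)
value, info = out[0], out[2]
print(value, info["neval"], n_solves)
```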
Following that reference, I tried using numba's jit and modified the block above as follows:
import numpy as np
import scipy.integrate as si
from scipy.optimize import root, fsolve
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def jit_integrand_function(integrand_function):
    jitted_function = numba.jit(integrand_function, nopython=True)
    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        return jitted_function(xx[0], xx[1])
    return LowLevelCallable(wrapped.ctypes)

@jit_integrand_function
def integrand(t, *args):
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

do_integrate(integrand, 2.)
However, this implementation gives me the error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'a' in a function that will escape.
File "<ipython-input-16-3d98286a4be7>", line 20:
def integrand(t, *args):
<source elided>
a = args[0]
c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
^
During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)
During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)
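The error comes from my use of fsolve from scipy.optimize inside the integrand.
I would like to know whether there is a way around this error, and whether scipy.optimize.fsolve can be used together with numba in this case.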
I wrote a small Python wrapper for Minpack called NumbaMinpack that can be called from within numba-compiled functions: https://github.com/Nicholaswogan/NumbaMinpack. You can use it to @njit the integrand:
import scipy.integrate as si
from NumbaMinpack import hybrd, minpack_sig
from numba import njit, cfunc
import numpy as np

@cfunc(minpack_sig)
def f(x, fvec, args):
    a = args[0]
    fvec[0] = a * x[0]**2.0 - np.exp(-x[0]**2.0 / a)
funcptr = f.address # pointer to function

@njit
def integrand(t, *args):
    a = args[0]
    args_ = np.array(args)
    x_init = np.array([1.0])
    sol = hybrd(funcptr, x_init, args_)
    c = sol[0][0]
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0])
On my computer, the code above takes 87 µs, compared with 2920 µs for the pure Python version.
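Separately, `c` depends only on `a`, not on `t`, so the root solve can be moved out of the integrand entirely, with no numba at all. Below is a minimal plain-SciPy sketch; the helper names are hypothetical. With `c` fixed, the substitution u = t/(a c) also gives a closed form, a c² (√π/2) erf(1/(a c)), which serves as a check on the quadrature:

```python
import numpy as np
import scipy.integrate as si
from scipy.optimize import fsolve
from scipy.special import erf


def solve_c(a):
    # Root of a*x**2 - exp(-x**2/a) = 0, started from x = 1 as in the question
    return fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]


def integrand_given_c(t, a, c):
    # Same integrand as before, but c is passed in instead of recomputed
    return c * np.exp(-(t / (a * c))**2)


def do_integrate_hoisted(a):
    c = solve_c(a)  # one root solve per integral, not one per quadrature node
    return si.quad(integrand_given_c, 0, 1, args=(a, c))[0]


def closed_form(a):
    # With u = t/(a*c): integral = a*c**2 * sqrt(pi)/2 * erf(1/(a*c))
    c = solve_c(a)
    return a * c**2 * np.sqrt(np.pi) / 2 * erf(1 / (a * c))


print(do_integrate_hoisted(2.0), closed_form(2.0))
```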