用ctypes从numba调用fortran



考虑文件 test.f90 中的这个 Fortran 模块:

!> Minimal C-interoperable module exporting a single scalar addition routine.
module mymod
  use, intrinsic :: iso_c_binding, only: c_double
  implicit none

contains

  !> Store a + b in c.
  !>
  !> Exported to C under the symbol name 'addstuff_wrap'. None of the dummy
  !> arguments carries the `value` attribute, so all three are passed by
  !> reference; from C it is callable as
  !>   void addstuff_wrap(double *a, double *b, double *c);
  subroutine addstuff(a, b, c) bind(c, name='addstuff_wrap')
    real(c_double), intent(in)  :: a   ! first summand
    real(c_double), intent(in)  :: b   ! second summand
    real(c_double), intent(out) :: c   ! result slot, overwritten on return

    c = a + b
  end subroutine addstuff

end module mymod

它可以编译成共享库(例如 `gfortran -shared -fPIC test.f90 -o test.so`)。我可以用下面的代码从 Python 调用它:

import ctypes as ct

# dlopen() only treats the name as a path (and thus looks in the current
# directory) when it contains a slash; a bare 'test.so' is searched for in the
# system library paths instead.
mylib = ct.CDLL('./test.so')
# The Fortran subroutine is exported under its bind(c) name, not 'addstuff'.
addstuff = mylib.addstuff_wrap
# bind(c) dummies without the `value` attribute are passed by reference,
# so every argument is a pointer to a double.
addstuff.argtypes = [ct.POINTER(ct.c_double), ct.POINTER(ct.c_double), ct.POINTER(ct.c_double)]
addstuff.restype = None  # a Fortran subroutine returns nothing (C void)

a = ct.c_double(1.0)
b = ct.c_double(2.0)
c = ct.c_double()  # output slot, written by the Fortran routine
addstuff(ct.byref(a), ct.byref(b), ct.byref(c))
print(c.value)  # 3.0

这会返回正确答案 3.0。然而,我想在经 Numba JIT 编译(`@njit`)的函数中调用它:

from numba import njit


@njit
def test(a, b):
    # NOTE: this attempt is expected to FAIL. In nopython mode numba cannot
    # resolve `ct.c_double` (it reports "Unknown attribute 'c_double' of type
    # Module(ctypes)"), so ctypes scalars cannot be constructed this way
    # inside a jitted function.
    c = ct.c_double()
    addstuff(ct.byref(ct.c_double(a)),
             ct.byref(ct.c_double(b)),
             ct.byref(c))
    return c.value


test(1.0, 2.0)  # raises numba's TypingError shown below

但这行不通,会报如下错误:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'c_double' of type Module(<module 'ctypes' from ...)
File "<ipython-input-14-f8fb94981395>", line 3:
def test(a, b):
c = ct.c_double()
^

有人知道变通办法吗?这很令人困扰,因为 Numba 确实声称支持 c_double 类型。

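以下是解决方案。Numba 支持把 ctypes 类型用于外部函数的签名(argtypes/restype),但不能在 nopython 模式中创建 ctypes 对象(即上面的 `ct.c_double(...)`)。因此这里改用 0 维 NumPy 数组存放数值,并把数组的指针传给 Fortran 例程: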

import ctypes as ct

import numpy as np
from numba import njit

# A slash makes dlopen() treat this as a path relative to the current
# directory; a bare 'test.so' would only be searched in the system paths.
mylib = ct.CDLL('./test.so')
# The Fortran subroutine is exported under its bind(c) name.
addstuff = mylib.addstuff_wrap
# Untyped pointers: numba maps c_void_p to its `voidptr` type, which accepts
# an array's `.ctypes` object.
addstuff.argtypes = [ct.c_void_p, ct.c_void_p, ct.c_void_p]
addstuff.restype = None  # Fortran subroutine -> C void


@njit
def test(a, b):
    """Return a + b computed by the Fortran routine, called from nopython mode.

    0-d float64 arrays stand in for the ctypes scalars that numba cannot
    construct inside jitted code; their buffers are handed to Fortran by
    address.
    """
    aa = np.array(a, np.float64)
    bb = np.array(b, np.float64)
    c = np.array(0.0, np.float64)  # output buffer, written by Fortran
    # Pass `arr.ctypes` rather than the bare integer `arr.ctypes.data`: in
    # numba the ctypes view also holds a reference to the array's memory, so
    # `aa`/`bb` cannot be released after their last Python-level use but
    # before the foreign call has read them.
    addstuff(aa.ctypes,
             bb.ctypes,
             c.ctypes)
    return c.item()

最新更新