如何使cython函数接受浮点或双数组输入?



>假设我有以下(MCVE...(cython函数

cimport cython
from scipy.linalg.cython_blas cimport dnrm2

cpdef double func(int n, double[:] x):
cdef int inc = 1
return dnrm2(&n, &x[0], &inc)

然后,我无法在np.float32数组x上调用它。

我怎样才能让func接受double[:]float[:],并交替打电话给dnrm2snrm2?我目前唯一的解决方案是有两个函数,这会创建大量重复的代码。

您可以使用融合类型。请注意,以下内容无法在我的系统上编译,因为ddotsdot显然需要 5 个参数:

# cython: infer_types=True
cimport cython
from scipy.linalg.cython_blas cimport ddot, sdot
ctypedef fused anyfloat:
double
float
cpdef anyfloat func(int n, anyfloat[:] x):
cdef int inc = 1
if anyfloat is double:
return ddot(&n, &x[0], &inc)
else:
return sdot(&n, &x[0], &inc)

最新更新