numpy排序错误与numba jit装饰器一致



我正在尝试实现numba。调用numpy的Jit函数。sort函数对numpy数组进行排序,但由于"已检测到从nopython编译路径回退到对象模式编译路径"而失败。我的代码如下:

gg = numpy.array ([[1,0,2],[1,2,1]],dtype = np.dtype((int,int)))
@nb.jit(nb.void(numba.int32[:,:]))
def kk (gg):
np.sort(gg)

我也尝试过njit模式,但也得到了错误:

"Failed in nopython mode pipeline (step: nopython frontend)
[1m[1m[1mNo implementation of function Function(<intrinsic stub>) found for signature:
>>> stub(array(int32, 2d, A))
There are 2 candidate implementations:
[1m  - Of which 2 did not match due to:
Intrinsic of function 'stub': File: numbacoreoverload_glue.py: Line 35.
With argument(s): '(array(int32, 2d, A))':"

我已经检查了numba文档,因为它显示了numpy。支持排序函数。我的代码有问题吗?或者排序函数只在对象模式下工作?

Numba不支持二维数组排序。要解决这个问题,可以遍历感兴趣的维度并对每一行或每一列进行排序。但是,这将比直接使用np.sort慢。

import numba as nb
import numpy as np
@nb.njit(nb.int32[:, :](nb.int32[:, :]))
def sort_by_second_axis(arr):
# Make a copy so we do not modify original.
arr = arr.copy()
for i in range(arr.shape[0]):
arr[i].sort()
return arr

下面是一个使用的例子:

prng = np.random.RandomState(42)
x = (prng.uniform(size=16) * 10).astype("int32").reshape(4, 4)
np.array_equal(np.sort(x), sort_by_second_axis(x))

如果您使用@nb.jit(nb.void(nb.int32[:]))(即,将其应用于一维数组),则警告会消失。Numba似乎不支持无python模式的非平面数组上的np.sort。这就是为什么它必须退回到对象模式。

import numba as nb
import numpy as np
@nb.jit(nb.void(nb.int32[:]))
def sortme(arr):
np.sort(arr)

我还会质疑在这种情况下是否需要numba。np.sort是用C语言实现的,并且已经编译好了。它非常快,从我的测试来看,numba有点慢。

import numba as nb
import numpy as np
@nb.njit(nb.int32[:](nb.int32[:]))
def sort_numba(arr):
return np.sort(arr)
prng = np.random.RandomState(seed=42)
x = (prng.rand(100_000) * 1_000).astype("int32")
assert np.array_equal(np.sort(x), sort_numba(x))
%timeit np.sort(x)
# 3.73 ms ± 45.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit sort_numba(x)
# 3.98 ms ± 37 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

最新更新