以numba no-python模式对numpy数组排序



Numba文档建议以下代码应该编译

@njit()
def accuracy(x,y,z):
x.argsort(axis=1)
# compare accuracy, this code works without the above line  
accuracy_y = int(((np.equal(y, x).mean())*100)%100)
accuracy_z = int(((np.equal(z, x).mean())*100)%100)
return accuracy_y,accuracy_z

它在x.argsort()上失败了,我也尝试了以下有和没有轴参数

np.argsort(x)
np.sort(x)
x.sort()

然而我得到以下编译失败错误(或类似):

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function sort at 0x000001B3B2CD2EE0>) found for signature:

>>> sort(array(int64, 2d, C))

There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload of function 'sort': File: numbacoretypingnpydecl.py: Line 665.
With argument(s): '(array(int64, 2d, C))':
No match.
During: resolving callee type: Function(<function sort at 0x000001B3B2CD2EE0>)

File "accuracy.py", line 148:
def accuracy(x,lm,sfn):
<source elided>
# accuracy
np.sort(x)
^

我在这里错过了什么?

如果适合您的用例,您也可以考虑使用guvectorize。这样做的好处是可以指定要排序的轴。在一个以上维度上的排序可以通过在不同轴上重复调用来完成。

@guvectorize("(n)->(n)")
def sort_array(x, out):
out[:] = np.sort(x)

使用稍微不同的示例数组,其中也有无序的列。

arr = np.array([
[6,5,4],
[3,2,1],
[9,8,7]],
)
sort_array(arr, out, axis=0)
sort_array(out, out, axis=1)

节目:

array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])

制作一个简单的包装器将允许一次对任意数量的维度进行排序。我不认为Numba的过载支持guvectorize,否则你甚至可以使用它来使np.sort在你的编译函数中工作,而不必改变任何东西。

https://numba.pydata.org/numba-doc/dev/extending/overloading-guide.html

测试与Numpy比较的输出:

for _ in range(20):

arr = np.random.randint(0, 99, (9,9))
# numba sorting
out_nb = np.empty_like(arr)
sort_array(arr, out_nb, axis=0)
sort_array(out_nb, out_nb, axis=1)
# numpy sorting
out_np = np.sort(arr, axis=0)
out_np = np.sort(out_np, axis=1)
np.testing.assert_array_equal(out_nb, out_np)

感谢您的评论!下面的函数对一个2d数组进行排序!

from numba import njit
import numpy  as np
@njit()
def sort_2d_array(x):
n,m=np.shape(x)
for row in range(n):
x[row]=np.sort(x[row])
return x
arr=np.array([[3,2,1],[6,5,4],[9,8,7]])
y=sort_2d_array(arr)
print(y)