如何使用numba更改数组中的索引



我有一个函数,我在其中执行一些操作,并希望使用numba来加速它。在我的代码中,使用高级索引更改数组中的值是不起作用的。我想他们在麻木的文件中确实这么说了。但是像numpy.put((这样的解决方法是什么?

这里有一个我想做的简单例子:

#example array
array([[ 0,  1,  2],
[ 0,  2, -1],
[ 0,  3, -1]])

使用在numba中工作的任何方法更改给定索引处的值。。。获得:在[0,0]、[1,2]、[2,1]时更改的值

#changed example array by given indexes with one given value (10)
array([[ 10,  1,  2],
[ 0,  2, 10],
[ 0,  10, -1]])

以下是我在python中所做的,但不使用numba:

indexList是一个Tuple,它与numpy.take((一起工作

这是python代码的工作示例,数组中的值更改为100。

x = np.zeros((151,151))
print(x.ndim)
indexList=np.array([[0,1,3],[0,1,2]])
indexList=tuple(indexList)
def change(xx,filter_list):
xx[filter_list] = 100
return xx
Z = change(x,indexList)

现在在函数上使用@jit:

@jit
def change(xx,filter_list):
xx[filter_list] = 100
return xx
Z = change(x,indexList)

编译正在返回到对象模式,并启用了环举,因为函数";改变";类型推理失败,原因是:找不到签名的函数function((的实现:setitem(数组(float64,2d,C(,UniTuple(阵列(int32,1d,C(x 2(,Literalint(

出现此错误。所以我需要一个变通办法。numba不支持numpy.put((。

我会很乐意提出任何想法。

谢谢

如果将indexList保留为数组不是问题,则可以将其与change函数中的for循环结合使用,使其与numba:兼容

indexList = np.array([[0,1,3],[0,1,2]]).T
@njit()
def change(xx, filter_list):
for y, x in filter_list:
xx[y, x] = 100
return xx
change(x, indexList)

注意,为了使yx坐标沿着第一轴,必须对indexList进行换位。换句话说,它必须具有(n, 2)而不是(2, n)的形状才能改变n点。实际上,它现在是一个坐标列表:[[0, 0],[1, 1],[3, 2]]

@mandulaj发布了前进的道路。在mandulaj给出答案之前,我走了一条有点不同的路。

使用此功能,我会收到一个弃用警告。。。所以最好的方法是使用@mandulaj,不要忘记转换indexList。

@jit
def change_arr(arr,idx,val): # change values in array by np index array to value
for i,item in enumerate(idx):
arr[item[0],item[1]]= val
return arr

最新更新