使用numpy add.at在整行中获取相同的值



我的问题如下,对于长度为数百万的二维数组np.add.是否比for上的简单迭代更快?

R将是值的二维数组,类似于[[171],[23,10]…],而nArray是一个矩阵25x25,它用np.zeros初始化,所以它只包含0(R中的所有值都在0到24之间,所以没有错误(

for point in R:
nArray[tuple(point)] += 1

我在np.add.at中遇到的问题是,我在一行的所有单元格上都有相同的值,而不是得到相同的结果。第0行->35525 35525。。。35525

第1行->20782078。。。2078

这个值与第一个实现中相应行的第一列上的值不同

np.add.at(nArray, R, 1)
In [182]: R = np.random.randint(0,25,(100,2)) 
In [197]: A0=np.zeros((25,25),int)                                                                     
In [198]: for point in R: A0[tuple(point)] += 1                                                        
In [199]: A1=np.zeros((25,25),int)                                                                     
In [200]: np.add.at(A1,(R[:,0],R[:,1]),1)                                                              
In [201]: np.allclose(A0,A1)                                                                           
Out[201]: True

我用了(R[:,0],R[:,1])而不是R,因为文档上说:

如果第一个操作数有多个维度,则索引可以是类似数组的索引对象或切片对象的元组。

关于您的tolist修复:

In [216]: A2=np.zeros((25,25),int)                                                                     
In [217]: np.add.at(A2,R.T.tolist(),1)                                                                 
/usr/local/bin/ipython3:1: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
#!/usr/bin/python3
In [218]: np.allclose(A0,A2)                                                                           
Out[218]: True
In [219]: A2=np.zeros((25,25),int)                                                                     
In [220]: np.add.at(A2,tuple(R.T.tolist()),1)                                                          
In [221]: np.allclose(A0,A2)                                                                           
Out[221]: True
In [208]: np.__version__                                                                               
Out[208]: '1.18.2'

显然,只有当R是一个numpy数组时才会发生这种情况,如果它首先被转换为列表(例如,通过.tolist(((,它的工作方式就像一样

最新更新