修剪cython中的numpy数组



目前我有以下cython函数,它修改用零填充的numpy数组的条目,以求和非零值。在返回数组之前,我想对其进行修剪并删除所有非零条目。目前,我使用numpy函数myarray = myarray[~np.all(myarray == 0, axis=1)]来实现这一点。我想知道(通常(是否有一种更快的方法可以使用Cython/C函数而不是依赖python/numpy。这是我的脚本中Python交互的最后一部分(通过使用to%%cython -a进行检查(。但我真的不知道如何处理这个问题。一般来说,我不知道最终数组中非零元素的先验数量。

cdef func():
np.ndarray[np.float64_t, ndim=2] myarray = np.zeros((lenpropen, 6)) 
"""
computations
"""
myarray = myarray[~np.all(myarray == 0, axis=1)]
return myarray

感谢@Jérôme Richard的评论。基于此(如果我的理解是正确的(,我尝试实现擦除-删除习语。下面给出的示例代码。

myarray = np.zeros((5000,6))
myarray[2] = [1,1,1,1,1,1]
@cython.boundscheck(False)  # Deactivate bounds checking                                                                  
@cython.wraparound(False)   # Deactivate negative indexing.                                                               
@cython.cdivision(True)     # Deactivate division by 0 checking.
cdef erase_remove( np.ndarray[np.float64_t, ndim=2] myarray):
cdef int idx 
cdef int cursor = 0
cdef int length_arr = 5000
for idx in range(5000):

if myarray[idx,0]!=0 and myarray[idx,1]!=0 and myarray[idx,2]!=0 and myarray[idx,3]!=0 and myarray[idx,4]!=0 and  myarray[idx,5]!=0:
myarray[cursor,0] = myarray[idx,0]
myarray[cursor,1] = myarray[idx,1]
myarray[cursor,2] = myarray[idx,2]
myarray[cursor,3] = myarray[idx,3]
myarray[cursor,4] = myarray[idx,4]
myarray[cursor,5] = myarray[idx,5]
cursor = cursor +1
else:
continue
return  myarray[0:cursor]    
start = timer()
myarray= erase_remove(myarray)
end = timer()
print("final", myarray)
print("time", end-start)

这产生输出

final [[1. 1. 1. 1. 1. 1.]]
time 1.1235475540161133e-05

与相比

myarray = np.zeros((5000,6))
print(myarray)
myarray[2] = [1,1,1,1,1,1]
print(myarray)
start = timer()
myarray = myarray[~np.all(myarray == 0, axis=1)]
end = timer()
print(myarray)
print("time", end-start)

产生输出

[[1. 1. 1. 1. 1. 1.]]
time 0.0006445050239562988

如果最高维度总是包含少量元素,如6,则您的代码不是最好的。

首先,myarray == 0np.all~创建临时数组,这会在需要写入和读取时引入一些额外的开销。开销取决于临时数组的开销,最大的开销是myarray == 0

此外,Numpy调用执行一些不需要的检查,Cython无法删除这些检查。这些检查引入了恒定的时间开销。因此,对于小的输入数组来说,is可能相当大,但对于大的输入数组则不然。

此外,如果np.all的代码知道最后一个维度的确切大小,那么它可以更快,而这里的情况并非如此。事实上,np.all的循环理论上可以展开,因为最后一个维度很小。不幸的是,Cython没有优化Numpy调用,并且Numpy是为可变输入大小编译的,因此在编译时是未知的

最后,如果lenpropen很大,计算可以并行化(否则不会更快,实际上可能更慢(。但是,请注意,并行实现需要分两个步骤进行计算:np.all(myarray == 0, axis=1)需要并行计算,然后您可以创建结果数组,并通过并行计算myarray[~result]来写入它。按顺序,您可以通过就地筛选行直接覆盖myarray,然后生成筛选行的视图。这种模式被称为删除习惯用法。请注意,这假设数组是连续的

总之,一个更快的实现包括在myarray上编写2个嵌套循环,其中最内层的循环迭代次数恒定。关于lenpropen的大小,您可以使用基于擦除-移除习惯用法的顺序就地实现,也可以使用具有两个步骤(和一个临时数组(的并行就地实现。

相关内容

  • 没有找到相关文章

最新更新