我有两个numpy布尔数组(a
和b
)。我需要找出它们的元素中有多少是相等的。目前,我使用len(a) - (a ^ b).sum()
,但据我所知,xor操作会创建一个全新的numpy数组。如何在不创建不必要的临时数组的情况下有效地实现这种期望的行为?
我试过使用numexpr,但我不能让它正常工作。它不支持True是1,False是0的概念,所以我必须使用ne.evaluate("sum(where(a==b, 1, 0))")
,这大约需要两倍的时间。
编辑:我忘了提到,其中一个数组实际上是另一个大小不同的数组的视图,两个数组都应该被认为是不可变的。这两个阵列都是二维的,大小往往在25x40左右。
是的,这是我的程序的瓶颈,值得优化。
在我的机器上,这更快:
(a == b).sum()
如果你不想使用任何额外的存储空间,我建议你使用numba。我对它不太熟悉,但这似乎很管用。我在让Cython获取布尔NumPy数组时遇到了一些问题。
from numba import autojit
def pysumeq(a, b):
tot = 0
for i in xrange(a.shape[0]):
for j in xrange(a.shape[1]):
if a[i,j] == b[i,j]:
tot += 1
return tot
# make numba version
nbsumeq = autojit(pysumeq)
A = (rand(10,10)<.5)
B = (rand(10,10)<.5)
# do a simple dry run to get it to compile
# for this specific use case
nbsumeq(A, B)
如果你没有numba,我建议使用@user2357112 的答案
编辑:刚刚有一个Cython版本在运行,这是.pyx
文件。我会同意的。
from numpy cimport ndarray as ar
cimport numpy as np
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
def cysumeq(ar[np.uint8_t,ndim=2,cast=True] a, ar[np.uint8_t,ndim=2,cast=True] b):
cdef int i, j, h=a.shape[0], w=a.shape[1], tot=0
for i in xrange(h):
for j in xrange(w):
if a[i,j] == b[i,j]:
tot += 1
return tot
首先可以跳过A*B步骤:
>>> a
array([ True, False, True, False, True], dtype=bool)
>>> b
array([False, True, True, False, True], dtype=bool)
>>> np.sum(~(a^b))
3
如果你不介意摧毁阵列a或b,我不确定你会比这个更快:
>>> a^=b #In place xor operator
>>> np.sum(~a)
3
如果问题是分配和释放,请维护一个输出数组,并告诉numpy每次都将结果放在那里:
out = np.empty_like(a) # Allocate this outside a loop and use it every iteration
num_eq = np.equal(a, b, out).sum()
不过,只有当输入始终是相同的维度时,这才会起作用。如果输入大小不同,您可能可以制作一个大数组,并根据每次调用所需的大小分割出一个部分,但我不确定这会减慢您的速度。
改进IanH的回答,还可以通过向ndarray提供mode="c"
,从Cython中访问numpy数组中的底层C数组。
from numpy cimport ndarray as ar
cimport numpy as np
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int cy_sum_eq(ar[np.uint8_t,ndim=2,cast=True,mode="c"] a, ar[np.uint8_t,ndim=2,cast=True,mode="c"] b):
cdef int i, j, h=a.shape[0], w=a.shape[1], tot=0
cdef np.uint8_t* adata = &a[0, 0]
cdef np.uint8_t* bdata = &b[0, 0]
for i in xrange(h):
for j in xrange(w):
if adata[j] == bdata[j]:
tot += 1
adata += w
bdata += w
return tot
这在我的机器上比IanH的Cython版本快了大约40%,我发现重新排列循环内容在这一点上似乎没有太大区别,可能是由于编译器优化。在这一点上,可以潜在地链接到使用SSE优化的C函数,以便执行此操作,并将adata
和bdata
作为uint8_t*
的传递