我想用numpy为一些统计对象做一个集合字典,简化状态如下。
分别有一个标量数组,标记为a = np.array([n1,n2,n3...])
和 2D 阵列作为b = np.array([[q1_1,q1_2],[q2_1,q2_2],[q3_1,q3_2]...])
对于a
ni
的每个元素,我想挑选出b
中包含ni
的所有元素qi([qi_1,qi_2])
,并用key
进行dict
,ni
收集它们。
为此,我记录了一个笨拙的方法(假设确定了a
和b
(到以下代码中,如下所示:
import numpy as np
a = np.array([i+1 for i in range(100)])
b = np.array([[2*i+1,2*(i+1)] for i in range(50)])
dict = {}
for i in a: dict[i] = [j for j in b if i in j]
毫无疑问,当a
和b
很大时,这将非常缓慢。 有没有其他有效的方法来替代上述方法? 寻求您的帮助!
感谢您的想法。它可以完全解决我的问题。你的核心概念是比较 a 和 b,并得到布尔数组作为结果。因此,将数组 b 的布尔索引用于构建字典要快得多。遵循这个想法,我以自己的方式重写您的代码
dict = {}
for item in a:
index_left, index_right = (b[:,0]==item), (b[:,1]==item)
index = np.logical_or(index_left, index_right)
dict[item] = dict[index]
这些代码仍然不比你的快,但即使在大的a和b中也可以避免"记忆错误"(例如a=100000和b=200000(
Numpy 数组允许元素比较:
equal = b[:,:,np.newaxis]==a #np.newaxis to broadcast
# if one of the two is equal, we will include this element
index = np.logical_or(equal[:,0], equal[:,1])
# indexing by a boolean array to get the result
dictionary = {i: b[index[:,i]] for i in range(len(a))}
最后要说一句:你确定要使用字典吗?这样你就会失去很多麻木优势
编辑,回答您的评论:
如果 a 和 b 这么大,相等,大小为 10^10,则产生 8*10^10 字节,大约为 72 G。这就是您收到此错误的原因。
您应该问的主要问题是:我真的需要这么大的数组吗?如果是,你确定字典也不会很大吗?
这个问题可以通过不一次计算所有内容来解决,但在n
的情况下,n
应该在你的情况下约为 72/16(内存中的比例(。但是,将n大一点可能会加快该过程:
stride = int(len(b)/n)
dictionary = {}
for i in range(n):
#splitting b into several parts
equal = b[n*stride:(n+1)*stride,:,np.newaxis]==a
index = np.logical_or(equal[:,0], equal[:,1])
dictionary.update( {i: b[index[:,i]] for i in range(len(a))})