使用条件(numpy.where)时更快的numpy数组索引?



我有一个形状为(50000000,3)的巨大numpy数组,我使用:

x = array[np.where((array[:,0] == value) | (array[:,1] == value))]

来获取我想要的那部分数组。但是这条路似乎很慢。是否有一种更有效的方法来使用numpy执行相同的任务?

np.where高度优化我怀疑有人能写出比上一个Numpy版本中实现的代码更快的代码(免责声明:我是优化它的人)。也就是说,这里的主要问题不是太多的np.where,而是创建临时布尔数组的条件。不幸的是,这是在Numpy中实现的方式,只要您只使用具有相同输入布局的Numpy,就没有太多可做的。

效率不高的一个原因是输入数据布局效率不高。实际上,假设array使用默认的行主顺序连续存储在内存中,array[:,0] == value将在内存中读取数组的每3项1项。由于CPU缓存的工作方式(即。缓存线,预取等),2/3的内存带宽被浪费。实际上,输出布尔数组也需要写入,并且由于页面错误,填充新创建的数组有点慢。。注意array[:,1] == value肯定会从RAM重新加载数据由于输入的大小(大多数CPU缓存无法容纳)。内存慢与CPU和缓存的计算速度相比,它要慢得多。这个被称为"记忆墙"的问题在几十年前就已经被观察到,而且预计不会很快得到解决。还要注意,逻辑-or还将创建一个从RAM读/写/到RAM的新数组。更好的数据布局是在内存中连续的(3, 50000000)转置数组(注意np.transpose不会产生连续数组)。

另一个解释性能问题的原因是Numpy倾向于不被优化为在非常小的轴上操作

.一个主要的解决方案是尽可能以转置的方式创建输入。另一个解决方案是编写一个Numba或Cython代码。下面是一个非转置输入的实现:

# Compilation for the most frequent types. 
# Please pick the right ones so to speed up the compilation time. 
@nb.njit(['(uint8[:,::1],uint8)', '(int32[:,::1],int32)', '(int64[:,::1],int64)', '(float64[:,::1],float64)'], parallel=True)
def select(array, value):
n = array.shape[0]
mask = np.empty(n, dtype=np.bool_)
for i in nb.prange(n):
mask[i] = array[i, 0] == value or array[i, 1] == value
return mask
x = array[select(array, value)]

请注意,我使用了并行实现,因为or运算符在Numba中是次优的(唯一的解决方案似乎是使用本机代码或Cython),而且在某些平台(如计算服务器)上,RAM不能被一个线程完全饱和。还请注意,对于select的结果,使用array[np.where(select(array, value))[0]]可能更快。实际上,如果结果是随机的或非常小,那么np.where可以更快,因为它针对这些情况进行了布尔索引无法执行的特殊优化。请注意,np.where并没有在Numba函数的上下文中进行特别优化,因为Numba使用自己的Numpy函数实现,并且它们有时对大型数组没有那么优化。更快的实现包括并行创建x,但这对于Numba来说不是微不足道的,因为输出项的数量事先不知道,线程必须知道在哪里写数据,更不用说Numpy已经相当快了,只要输出可预测