Speeding up index lookups in a numpy array



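I have a 1d numpy array of strings (dtype='U') of length 15 million, called ops, and I need to find all the indices at which a string called op occurs. This lookup has to be done 83,000 times.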

So far numpy wins the race, but it still takes about 3 hours: indices = np.where(ops==op). I also tried np.unravel_index(np.where(ops.ravel()==op), ops.shape)[0][0], which made little difference.
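For reference, the baseline described above can be sketched roughly as follows. The array contents and the `queries` list are made-up stand-ins (the real `ops` has about 15 million entries and there are 83,000 lookups). The point is only that every `ops == op` comparison scans the whole array, so the total work grows with the number of queries times the array length.

```python
import numpy as np

# Hypothetical stand-in data; the real ops has about 15 million entries.
ops = np.array(["add", "mul", "add", "sub", "mul", "add"])
# Hypothetical query list; the real workload performs 83,000 lookups.
queries = ["add", "sub"]

# Each comparison walks the full array, so this costs len(queries) * len(ops) string comparisons.
baseline = {op: np.where(ops == op)[0] for op in queries}

print(baseline["add"])  # positions where ops == "add"
print(baseline["sub"])  # positions where ops == "sub"
```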

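I am trying a cython approach, using random data similar to the original, but it is about 40 times slower than numpy's solution. This is my first cython code, so maybe it can be improved. The cython code: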

import numpy as np
cimport numpy as np
def get_ixs(np.ndarray data, str x, np.ndarray[int,mode="c",ndim=1] xind):
    cdef int count, n, i
    count = 0
    n = data.shape[0]
    i = 0
    while i < n:
        if (data[i] == x):
            xind[count] = i
            count += 1
        i += 1
    return xind[0:count]

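If get_ixs is called many times with the same data, the fastest solution is to preprocess data into a dict, which then gives O(1) (constant-time) lookup for each query string.
The dict's keys are the strings x, and the value for each key is the list of indices i that satisfy data[i] == x.
Here is the code: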

import numpy as np
data = np.array(["toto", "titi", "toto", "titi", "tutu"])
indices = np.arange(len(data))
# sort data so that we can construct the dict by replacing list with ndarray as soon as possible (when string changes) to reduce memory usage
indices_data_sorted = np.argsort(data)
data = data[indices_data_sorted]
indices = indices[indices_data_sorted]
# construct the dict str -> ndarray of indices (use ndarray for lower memory consumption)
dict_str_to_indices = dict()
prev_str = None
list_idx = []  # list to hold the indices for a given string
for i, s in zip(indices, data):
    if s != prev_str:
        # the current string has changed so we can construct the ndarray and store it in the dict
        if prev_str is not None:
            dict_str_to_indices[prev_str] = np.array(list_idx, dtype="int32")
        list_idx.clear()
        prev_str = s
    list_idx.append(i)

dict_str_to_indices[s] = np.array(list_idx, dtype="int32")  # add the ndarray for last string
def get_ixs(dict_str_to_indices: dict, x: str):
    return dict_str_to_indices[x]
print(get_ixs(dict_str_to_indices, "toto"))
print(get_ixs(dict_str_to_indices, "titi"))
print(get_ixs(dict_str_to_indices, "tutu"))

Output:

[0 2]
[1 3]
[4]

If get_ixs is called many times with the same dict_str_to_indices, this is the asymptotically optimal solution (O(1) lookup).
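As a variation, the same kind of string-to-indices dict could probably be built without a Python-level loop over every element, by letting np.unique assign a group id to each string and then splitting the sorted positions at the group boundaries. This is only a sketch of a different technique, not the code above: `build_index` is a hypothetical helper name, and it assumes `data` is a 1-D array.

```python
import numpy as np

def build_index(data):
    """Map each distinct string in a 1-D array to an ndarray of its positions."""
    # uniques: sorted distinct strings, inverse: group id of each element,
    # counts: number of elements in each group.
    uniques, inverse, counts = np.unique(data, return_inverse=True, return_counts=True)
    # A stable sort of the group ids lists positions group by group,
    # in ascending order inside each group.
    order = np.argsort(inverse, kind="stable")
    # Cut the ordered positions at the group boundaries.
    groups = np.split(order, np.cumsum(counts)[:-1])
    return {str(u): g for u, g in zip(uniques, groups)}

data = np.array(["toto", "titi", "toto", "titi", "tutu"])
index = build_index(data)
print(index["toto"])
# A string that never occurs can be given an empty result instead of a KeyError.
print(index.get("missing", np.empty(0, dtype=np.intp)))
```

The preprocessing cost is dominated by the sort inside np.unique and is paid once. After that, each of the repeated queries is a single dict lookup.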
