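I have a 1-D NumPy string array (dtype='U') of length 15MM (15 million), called ops, in which I need to find all indices where a given string, called op, occurs. I have to run this lookup 83,000 times.
So far NumPy wins the race, but it still takes about 3 hours: indices = np.where(ops==op)
I also tried np.unravel_index(np.where(ops.ravel()==op), ops.shape)[0][0], which made little difference.
I am also trying a Cython approach on random data similar to the original, but it is about 40 times slower than the NumPy solution. This is my first Cython code, so perhaps it can be improved. Cython code: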
import numpy as np
cimport numpy as np

def get_ixs(np.ndarray data, str x, np.ndarray[int,mode="c",ndim=1] xind):
    cdef int count, n, i
    count = 0
    n = data.shape[0]
    i = 0
    while i < n:
        if (data[i] == x):
            xind[count] = i
            count += 1
        i += 1
    return xind[0:count]
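If get_ixs is called many times with the same data, the fastest solution is to preprocess data into a dict, which then gives O(1) (constant-time) lookup when querying a string.
The keys of the dict are the strings x, and the value for each key is the list of indices i that satisfy data[i] == x.
Here is the code: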
import numpy as np

data = np.array(["toto", "titi", "toto", "titi", "tutu"])
indices = np.arange(len(data))

# sort data so that we can construct the dict by replacing list with ndarray as soon as possible (when string changes) to reduce memory usage
# kind="stable" keeps the indices of equal strings in ascending order
indices_data_sorted = np.argsort(data, kind="stable")
data = data[indices_data_sorted]
indices = indices[indices_data_sorted]

# construct the dict str -> ndarray of indices (use ndarray for lower memory consumption)
dict_str_to_indices = dict()
prev_str = None
list_idx = []  # list to hold the indices for a given string
for i, s in zip(indices, data):
    if s != prev_str:
        # the current string has changed so we can construct the ndarray and store it in the dict
        if prev_str is not None:
            dict_str_to_indices[prev_str] = np.array(list_idx, dtype="int32")
        list_idx.clear()
        prev_str = s
    list_idx.append(i)
dict_str_to_indices[s] = np.array(list_idx, dtype="int32")  # add the ndarray for last string

def get_ixs(dict_str_to_indices: dict, x: str):
    return dict_str_to_indices[x]

print(get_ixs(dict_str_to_indices, "toto"))
print(get_ixs(dict_str_to_indices, "titi"))
print(get_ixs(dict_str_to_indices, "tutu"))
Output:
[0 2]
[1 3]
[4]
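Since the question involves 83,000 lookups, some query strings may not occur in data at all, in which case a plain dict lookup raises KeyError. A minimal sketch of one way to handle that, assuming an empty index array is an acceptable answer for a missing string. The names get_ixs_or_empty, EMPTY and queries are hypothetical, and the small dict is only a stand-in for the one built from data:

```python
import numpy as np

# Hypothetical stand-in for the dict built from data: string -> ndarray of indices.
dict_str_to_indices = {
    "toto": np.array([0, 2], dtype="int32"),
    "titi": np.array([1, 3], dtype="int32"),
    "tutu": np.array([4], dtype="int32"),
}

# Shared result for strings that never occur (assumption: empty array is acceptable).
EMPTY = np.empty(0, dtype="int32")

def get_ixs_or_empty(dict_str_to_indices: dict, x: str) -> np.ndarray:
    # dict.get avoids a KeyError when x is not a key of the dict
    return dict_str_to_indices.get(x, EMPTY)

queries = ["toto", "tata", "tutu"]  # "tata" does not occur in the data
results = {q: get_ixs_or_empty(dict_str_to_indices, q) for q in queries}
for q, ixs in results.items():
    print(q, ixs)
```

Each query is still a single hash lookup, so the total cost of all queries stays proportional to the number of queries rather than to the length of data.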
If get_ixs is called many times with the same dict_str_to_indices, this is the asymptotically optimal solution (O(1) lookup).
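For an array with millions of entries, the Python-level loop that builds the dict can itself become slow. The following is a sketch of the same sort-and-group idea with the grouping done by NumPy, not the code from the answer. The function name build_index is hypothetical, and it assumes data is a 1-D string array:

```python
import numpy as np

def build_index(data: np.ndarray) -> dict:
    """Map each distinct string in data to an int32 array of its positions."""
    if data.size == 0:
        return {}
    # Stable sort: positions of equal strings stay in ascending order.
    order = np.argsort(data, kind="stable")
    sorted_data = data[order]
    # Offsets in the sorted array where the string differs from its predecessor.
    change = np.flatnonzero(sorted_data[1:] != sorted_data[:-1]) + 1
    starts = np.concatenate(([0], change))
    keys = sorted_data[starts]  # first element of each run of equal strings
    groups = np.split(order.astype("int32"), change)  # one index array per run
    return {str(k): g for k, g in zip(keys, groups)}

data = np.array(["toto", "titi", "toto", "titi", "tutu"])
index = build_index(data)
print(index["toto"])
print(index["titi"])
print(index["tutu"])
```

The sort costs O(n log n) once. After that, each of the lookups is a constant-time dict access, which is where the gain over repeated np.where scans comes from.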