在numpy库中,可以将列表传递到numpy.searchsorted
函数中,通过该函数,它一次搜索一个不同的列表元素,并返回一个与保持顺序所需的索引大小相同的数组。然而,如果对两个列表都进行排序,似乎会浪费性能。例如:
m=[1,3,5,7,9]
n=[2,4,6,8,10]
numpy.searchsorted(m,n)
会返回[1,2,3,4,5]
,这是正确的答案,但看起来它的复杂性为O(n ln(m((,因此,如果简单地循环通过m,并有某种指向n的指针,那么复杂性似乎更像O(n+m(?NumPy中是否有某种函数可以做到这一点?
AFAIK,如果不对输入进行额外假设(例如,整数较小且有界(,则仅使用Numpy在线性时间内不可能做到这一点。另一种解决方案是使用Numba手动进行合并:
import numba as nb
# Note: Numba requires a function signature with well defined array types
@nb.njit('int64[:](int64[::1], int64[::1])')
def search_both_sorted(a, b):
i, j = 0, 0
result = np.empty(b.size, np.int64)
while i < a.size and j < a.size:
if a[i] < b[j]:
i += 1
else:
result[j] = i
j += 1
for k in range(j, b.size):
result[k] = i
return result
a, b = np.cumsum(np.random.randint(0, 100, (2, 1000000)).astype(np.int64), axis=1)
result = search_both_sorted(a, b)
更快的实现在于使用无分支方法,以便在a
和b
大约相同大小时消除分支错误预测的开销(尤其是在随机/不可预测输入上(。此外,当b
小时,O(n log m)
算法可能更快,因此在这种情况下使用np.searchsorted
是非常有效的,正如@MichaelSzczesny所指出的。请注意,np.searchsorted
的Numba实现可能比Numpy的慢一点,因此最好选择Numpy实现。这是优化版本:
@nb.njit('int64[:](int64[::1], int64[::1])')
def search_both_sorted_opt_numba(a, b):
sa, sb = a.size, b.size
# Choose the best algorithm
if sb < sa * 0.15:
# Use a version with branches because `a[i] < b[j]`
# should be most of the time true.
i, j = 0, 0
result = np.empty(b.size, np.int64)
while i < a.size and j < b.size:
if a[i] < b[j]:
i += 1
else:
result[j] = i
j += 1
for k in range(j, b.size):
result[k] = i
else:
# Use a branchless approach to avoid miss-predictions
i, j = 0, 0
result = np.empty(b.size, np.int64)
while i < a.size and j < b.size:
tmp = a[i] < b[j]
result[j] = i
i += tmp
j += ~tmp
for k in range(j, b.size):
result[k] = i
return result
def search_both_sorted_opt(a, b):
sa, sb = a.size, b.size
# Choose the best algorithm
if 2 * sb * np.log2(sa) < sa + sb:
return np.searchsorted(a, b)
else:
return search_both_sorted_opt_numba(a, b)
searchsorted: 19.1 ms
snp_search: 11.8 ms
search_both_sorted: 6.5 ms
search_both_sorted_branchless: 4.3 ms
优化的无分支Numba实现比searchsorted
快大约4.4倍,考虑到searchsorted
的代码已经高度优化,这是非常好的。当a
和b
很大时,由于缓存位置,它可能会更快。
您可以使用sortednp,不幸的是,它没有提供太多的灵活性。在下面的代码片段中,我使用了它的合并跟踪索引,但它产生了三个数组,使用的内存是所需内存的四倍,但它比searchsorted快。
import numpy as np
import sortednp as snp
a = np.cumsum(np.random.rand(1000000))
b = np.cumsum(np.random.rand(1000000))
def snp_search(a,b):
m, (ib, ia) = snp.merge(b, a, indices=True)
return ib - np.arange(len(ib))
assert(np.all(snp_search(a,b) == np.searchsorted(a,b)))
np.searchsorted(a, b); #58 ms
snp_search(a,b); # 22ms
np.searchsorted
已经考虑到了这一点,从源代码中可以看出
/*
* Updating only one of the indices based on the previous key
* gives the search a big boost when keys are sorted, but slightly
* slows down things for purely random ones.
*/
if (cmp(last_key_val, key_val)) {
max_idx = arr_len;
}
else {
min_idx = 0;
max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len;
}
这里CCD_ 14用于对阵列执行二进制搜索。如果last_key_val < key_val
,则只有max_idx
被重置为数组长度,但min_idx
保持在其当前值,即二进制搜索在与前一个关键字相同的下边界处开始。