获取当前值(flood fill)的邻居索引的最快方法



我需要找到一种快速的方法来获取具有当前值的邻居的索引

例如:

arr = [0, 0, 0, 1, 0, 1, 1, 1, 1, 0]
indicies = func(arr, 6)
# [5, 6, 7, 8]

第6个元素的值为1,所以我需要包含第6个元素以及所有具有相同值

的相邻元素的完整切片它就像洪水填充算法的一部分。有没有办法在numpy中快速完成?有二维阵列的方法吗?

编辑

让我们看一些性能测试:

import numpy as np
import random
np.random.seed(1488)
arr = np.zeros(5000)
for x in np.random.randint(0, 5000, size = 100):
arr[x:x+50] = 1

我将比较@Ehsan的函数:

def func_Ehsan(arr, idx):
change = np.insert(np.flatnonzero(np.diff(arr)), 0, -1)
loc = np.searchsorted(change, idx)
start = change[max(loc-1,0)]+1 if loc<len(change) else change[loc-1]
end = change[min(loc, len(change)-1)]
return (start, end)
change = np.insert(np.flatnonzero(np.diff(arr)), 0, -1)
def func_Ehsan_same_arr(arr, idx):
loc = np.searchsorted(change, idx)
start = change[max(loc-1,0)]+1 if loc<len(change) else change[loc-1]
end = change[min(loc, len(change)-1)]
return (start, end)

与我的纯python函数:

def my_func(arr, index):

val = arr[index]
size = arr.size

end = index + 1
while end < size and arr[end] == val:
end += 1
start = index - 1
while start > -1 and arr[start] == val:
start -= 1

return start + 1, end

看一看:

np.random.seed(1488)
%timeit my_func(arr, np.random.randint(0, 5000))
# 42.4 µs ± 700 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

np.random.seed(1488)
%timeit func_Ehsan(arr, np.random.randint(0, 5000))
# 115 µs ± 1.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
np.random.seed(1488)
%timeit func_Ehsan_same_arr(arr, np.random.randint(0, 5000))
# 18.1 µs ± 953 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

是否有一种方法可以使用numpy相同的逻辑,没有C模块/Cython/Numba/python循环?让它更快!

我不知道如何用numpy解决这个问题,但是如果你使用pandas,你可能会得到你想要的结果:

import pandas as pd
df=pd.DataFrame(arr,columns=["data"])
df["new"] = df["data"].diff().ne(0).cumsum()
[{i[0]:j.index.tolist()} for i,j in df.groupby(["data","new"],sort=False)]

输出:

[{0: [0, 1, 2]}, {1: [3]}, {0: [4]}, {1: [5, 6, 7, 8]}, {0: [9]}]

主要问题是Numpy目前还不能有效地解决这个问题. A"快速找到第一个值索引";或任何类似的惰性函数需要调用来有效地解决这个问题。然而,尽管这个特性早在10年前就已经讨论过了,但在Numpy中仍然没有这个特性。更多信息请看这篇文章。我不指望很快会有任何改变。在此之前,相对较大的数组的最佳解决方案似乎是使用相对较慢的纯python循环和较慢的Numpy调用/访问的迭代解决方案。

除此之外,加速计算的一个解决方案是处理小块. 下面是一个实现:
def my_func_opt1(arr, index):
val = arr[index]
size = arr.size
chunkSize = 128

end = index + 1
while end < size:
chunk = arr[end:end+chunkSize]
locations = (chunk != val).nonzero()[0]
if len(locations) > 0:
foundCount = locations[0]
end += foundCount
break
else:
end += len(chunk)
start = index
while start > 0:
chunk = arr[max(start-chunkSize,0):start]
locations = (chunk != val).nonzero()[0]
if len(locations) > 0:
foundCount = locations[-1]
start -= chunkSize - 1 - foundCount
break
else:
start -= len(chunk)

return start, end

以下是我的机器的性能结果:

func_Ehsan:   53.8  µs ± 449 ns  per loop (mean ± std. dev. of 7 runs, 10000 loops each)
my_func:      17.5  µs ± 97.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
my_func_opt1:  7.31 µs ± 52.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

问题是结果有点偏,因为np.random.randint实际花费2.01µs。如果基准测试中没有包含这个Numpy调用,结果如下:

func_Ehsan:   51.8  µs
my_func:      15.5  µs
my_func_opt1:  5.31 µs
结果,my_func_opt1的速度约为的3倍my_func。编写更快的代码是非常困难的,因为任何Numpy调用都会在我的机器上引入相对较大的0.5-1.0µs开销,无论数组大小如何(例如: 由于内部检查导致的空数组。

以下是对加速操作感兴趣的人可以使用Numba的有用信息。

最简单的解决方案是使用Numba的JIT更具体地说就是添加decorator。这个解决方案也非常快。

@nb.njit('UniTuple(i8,2)(f8[::1], i8)')
def my_func_opt2(arr, index):
val = arr[index]
size = arr.size

end = index + 1
while end < size and arr[end] == val:
end += 1
start = index - 1
while start > -1 and arr[start] == val:
start -= 1

return start + 1, end

在我的机器上my_func_opt2只需要0.63µs(排除随机调用)。因此,my_func_opt2快25倍my_func。我非常怀疑是否有更快的解决方案,因为在我的机器上,任何Numpy调用至少需要0.5µs,而一个空的Numba函数调用需要0.25µs。


除此之外,请注意arr包含双精度值,计算起来非常昂贵。如果可以的话,使用整数应该会更快。另外,请注意,0和1值的数组可以存储在int8值中,它占用的内存少8倍,并且通常计算速度更快(由于CPU缓存,数组越小计算速度越快)。您可以在创建数组时指定类型:np.zeros(5000, dtype=np.int8)

这是一个numpy解决方案。我认为你可以再努力一点来改进它:

def func(arr, idx):
change = np.insert(np.flatnonzero(np.diff(arr)), 0, -1)
loc = np.searchsorted(change, idx)
start = change[max(loc-1,0)]+1 if loc<len(change) else change[loc-1]
end = change[min(loc, len(change)-1)]
return np.arange(start, end)

样本输出:

indices = func(arr, 6)
#[5 6 7 8]

如果您的arr(相对于数组大小)变化很少,并且您正在寻找相同数组中的多个索引搜索,则这将特别快。否则,我们会想到更快的解决方案。

性能比较:如果多次对同一个数组执行操作,只需像这样将第一行移出函数,以避免重复。

change = np.insert(np.flatnonzero(np.diff(arr)), 0, -1)
def func(arr, idx):
loc = np.searchsorted(change, idx)
start = change[max(loc-1,0)]+1 if loc<len(change) else change[loc-1]
end = change[min(loc, len(change)-1)]
return np.arange(start, end)

对于与OP相同的输入:

np.random.seed(1488)
%timeit func_OP(arr, np.random.randint(0, 5000))
#23.5 µs ± 631 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
np.random.seed(1488)
%timeit func_Ehsan(arr, np.random.randint(0, 5000))
#7.89 µs ± 113 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
np.random.seed(1488)
%timeit func_Jérôme_opt1(arr, np.random.randint(0, 5000))
#12.1 µs ± 757 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit func_Jérôme_opt2(arr, np.random.randint(0, 5000))
#3.45 µs ± 179 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
与<<p> strong> func_Ehsan 最快(不包括Numba)。请再次注意,这些函数的性能在数组的变化次数、数组大小和函数在同一数组上被调用的次数上有很大的不同。当然,Numba比所有这些都快(几乎比func_Ehsan快2倍)。如果你要多次运行它,在O(n)中构建组,并使用哈希映射到O(1)中的索引。