从掩码Python中获取切片索引



我有一个浮动的(N,)数组(arr(,但我只关心>=给定的CCD_ 3。我可以获得这样的面具:

mask = (arr >= threshold)

现在我想要一个对应切片索引的(N,2)数组。

例如,如果arr = [0, 0, 1, 1, 1, 0, 1, 1, 0, 1]threshold = 1,那么mask = [False, False, True, True, True, False, True, True, False, True]和我想要索引[ [2, 5], [6, 8], [9, 10] ](我可以将其用作arr[2:5], arr[6:8], arr[9:10]以获得其中arr >= threshold的段(。

目前,我有一个丑陋的for循环解决方案,在将相应的切片索引附加到列表之前,它遵循True的每一段。有没有一种更简洁易读的方法来实现这一结果?

您可以通过将掩码布尔值与其后续值进行比较来使用掩码计算开始索引和结束索引的列表。然后连接开始和结束以形成范围(所有这些都使用numpy方法矢量化(:

import numpy as np
arr       = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
threshold = 1
mask      = arr >= threshold
starts    = np.argwhere(np.insert(mask[:-1],0,False)<mask)[:,0]
ends      = np.argwhere(np.append(mask[1:],False)<mask)[:,0]+1
indexes   = np.stack((starts,ends)).T
print(starts)  # [2 6 9]
print(ends)    # [5 8 10]
print(indexes)
[[ 2  5]
[ 6  8]
[ 9 10]]

如果你想在Python元组列表中得到结果:

indexes = list(zip(starts,ends))  # [(2, 5), (6, 8), (9, 10)]

如果你不需要(或不想(使用numpy,你可以使用itertools:中的groupby直接从arr中获取范围

from itertools import groupby
indexes = [ (t[1],t[-1]+1) for t,t[1:] in 
groupby(range(len(arr)),lambda i:[arr[i]>=threshold]) if t[0]]
print(indexes)
[(2, 5), (6, 8), (9, 10)]

您可以使用带有key参数的itertools groupby以及enumerate来获取分组。如果组值都是True,则可以取第一个和最后一个+1值。

from itertools import groupby
import numpy as np
arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
threshold  = 1

idx = []
for group,data in groupby(enumerate((arr >= threshold)), key=lambda x:x[1]):
d = list(data)
if all(x[1]==True for x in d):
idx.append([d[0][0], d[-1][0]+1])

输出

[[2, 5], [6, 8], [9, 10]]

您可以使用np.flatnonzeronp.diff:的组合

indexes = np.flatnonzero(np.diff(np.append(arr >= threshold, 0))) + 1
indexes = list(zip(indexes[0::2], indexes[1::2]))

输出:

>>> indexes
[(2, 5), (6, 8), (9, 10)]

最新更新