查找NumPy数组中大于N的值的开始/停止索引范围



假设我有一个NumPy数组:

x = np.array([2, 3, 4, 0, 0, 1, 1, 4, 6, 5, 8, 9, 9, 4, 2, 0, 3])

对于x >= 2中的所有值,我需要找到x >=2的连续值(即,不计算一个大于或等于2的单个值的运行(的开始/停止索引。然后,我对x >= 3x >=4、…、。。。,x >= x.max()输出应该是NumPy数组三列(第一列是最小值,第二列是包含起始索引,第三列是停止索引(,看起来像:

[[2,  0,  2],
 [2,  7, 14],
 [3,  1,  2],
 [3,  7, 13],
 [4,  7, 13],
 [5,  8, 12],
 [6, 10, 12],
 [8, 10, 12],
 [9, 11, 12]
]

天真地,我可以浏览每个唯一的值,然后搜索开始/停止索引。然而,这需要在x上进行多次传递。完成此任务的最佳NumPy矢量化方式是什么?有没有一种解决方案不需要对数据进行多次传递?

更新

我意识到我还需要计算单个实例。所以,我的输出应该是:

[[2,  0,  2],
 [2,  7, 14],
 [2, 16, 16],  # New line needed
 [3,  1,  2],
 [3,  7, 13],
 [3, 16, 16],  # New line needed
 [4,  2,  2],  # New line needed
 [4,  7, 13],
 [5,  8, 12],
 [6,  8,  8],  # New line needed
 [6, 10, 12],
 [8, 10, 12],
 [9, 11, 12]
]

这里有另一个解决方案(我认为可以改进(:

import numpy as np
from numpy.lib.stride_tricks import as_strided
x = np.array([2, 3, 4, 0, 0, 1, 1, 4, 6, 5, 8, 9, 9, 4, 2, 0, 3])
# array of unique values of x bigger than 1
a = np.unique(x[x>=2])
step = len(a)  # if you encounter memory problems, try a smaller step
result = []
for i in range(0, len(a), step):
    ai = a[i:i + step]
    c = np.argwhere(x >= ai[:, None])
    c[:,0] = ai[c[:,0]]
    c =  np.pad(c, ((1,1), (0,0)), 'symmetric')
    d = np.where(np.diff(c[:,1]) !=1)[0]
    e = as_strided(d, shape=(len(d)-1, 2), strides=d.strides*2).copy()
    # e = e[(np.diff(e, axis=1) > 1).flatten()]
    e[:,0] = e[:,0] + 1 
    result.append(np.hstack([c[:,0][e[:,0, None]], c[:,1][e]]))
result = np.concatenate(result)
# array([[ 2,  0,  2],
#        [ 2,  7, 14],
#        [ 2, 16, 16],
#        [ 3,  1,  2],
#        [ 3,  7, 13],
#        [ 3, 16, 16],
#        [ 4,  2,  2],
#        [ 4,  7, 13],
#        [ 5,  8, 12],
#        [ 6,  8,  8],
#        [ 6, 10, 12],
#        [ 8, 10, 12],
#        [ 9, 11, 12]])

很抱歉没有评论每一步的作用——如果以后我能抽出时间,我会修复它。

这确实是一个非常有趣的问题。我试图把它分成三部分来解决。

分组:

import numpy as np
import pandas as pd
x = np.array([2, 3, 4, 0, 0, 1, 1, 4, 6, 5, 8, 9, 9, 4, 2, 0, 3])
groups = pd.DataFrame(x).groupby([0]).indices

所以群是字典{0: [3, 4, 15], 1: [5, 6], 2: [0, 14], 3: [1, 16], 4: [2, 7, 13], 5: [9], 6: [8], 8: [10], 9: [11, 12]},它的值是dtype=int64numpy数组。

掩蔽:

在这一部分中,我对每个唯一值i的几个掩码数组x>=i按降序进行迭代:

mask_array = np.zeros(x.size).astype(int)
for group in list(groups)[::-1]:
    mask = mask_array[groups[group]] = 1
    # print(group, ':', mask_array)
    # output = find_slices(mask)

这些口罩看起来是这样的:

9 : [0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0]
8 : [0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0]
6 : [0 0 0 0 0 0 0 0 1 0 1 1 1 0 0 0 0]
5 : [0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0]
4 : [0 0 1 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
3 : [0 1 1 0 0 0 0 1 1 1 1 1 1 1 0 0 1]
2 : [1 1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 1]
1 : [1 1 1 0 0 1 1 1 1 1 1 1 1 1 1 0 1]
0 : [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]

从掩码中提取切片

我希望构造一个名为find_slices的函数,从掩码数组中提取切片位置(如果取消注释的话(。这就是我所做的:

def find_slices(m):
    m1 = np.r_[0, m]
    m2 = np.r_[m, 0]
    starts, = np.where(~m1 & m2)
    ends, = np.where(m1 & ~m2)
    return np.c_[starts, ends - 1]

例如,阵列[0 1 1 0 0 0 0 1 1 1 1 1 1 1 0 0 1]的切片位置将是[[1, 2], [7, 13], [16, 16]]。请注意,这不是返回切片的标准方式,结束位置通常增加1。

最终脚本

毕竟,一个人需要一些策略来达到预期的输出,这里就像它在最后看起来一样:

import numpy as np
import pandas as pd
x = np.array([2, 3, 4, 0, 0, 1, 1, 4, 6, 5, 8, 9, 9, 4, 2, 0, 3])
groups = pd.DataFrame(x).groupby([0]).indices
mask_array = np.zeros(x.size).astype(bool)
m = []
for group in list(groups)[::-1]:
    mask_array[groups[group]] = True
    s = find_slices(mask_array)
    group_output = np.c_[np.repeat(group, s.shape[0]), s] #insert first column
    m.append(group_output) 
output = np.concatenate(m[::-1])
output = output[output[:,1]!= output[:,2]] #elimate slices with unit length

输出:

 [[ 0  0 16]
 [ 1  0  2]
 [ 1  5 14]
 [ 2  0  2]
 [ 2  7 14]
 [ 3  1  2]
 [ 3  7 13]
 [ 4  7 13]
 [ 5  8 12]
 [ 6 10 12]
 [ 8 10 12]
 [ 9 11 12]]

最新更新