如何快速获取numpy数组中非零值的索引



现在我正在编写一个函数,该函数将通过以下规则获得非零值的索引:

  1. 期望的结果是一个列表。每个元素表示非零值的连续切片的索引。所以对于[0,0,0,1,1,1,0,1,1,0]的列表,它应该得到列表[[3,4,5], [7,8]]
  2. 列表中不同值的索引应在分隔列表中,即对于[0,0,1,1,1,2,2,1,1,0]列表,预期结果为[[2,3,4],[5,6],[7,8]]

你知道吗?提前感谢!

arr作为输入数组,并将数组列表作为输出,您可以这样做-

# Store non-zero element indices
idx = np.where(arr)[0]
# Get indices where the shifts occur, i.e. positions where groups of identical 
# elements are separated. For this we perform differnetiation and look for 
# non-zero values and then get those positions. Finally, add 1 to compensate 
# for differentiation that would have decreased those shift indices by 1.
shift_idx = np.where(np.diff(arr[idx])!=0)[0]+1
# Split the non-zero indices at those shifts for final output
out = np.split(idx,shift_idx)

输入、输出示例-

In [35]: arr
Out[35]: array([0, 0, 1, 1, 1, 2, 2, 1, 1, 0, 2, 2, 4, 3, 3, 3, 0])
In [36]: out
Out[36]: 
[array([2, 3, 4]),
 array([5, 6]),
 array([7, 8]),
 array([10, 11]),
 array([12]),
 array([13, 14, 15])]

最新更新