从numpy阵列中删除除最后一个非零序列以外的所有序列



问题

i具有一个一维的numpy阵列,主要用零填充,但也包含一些非零值组。

>> import numpy as np
>> a = np.zeros(10)
>> a[2:4] = 2
>> a[6:9] = 3
>> print a
[ 0.  0.  2.  2.  0.  0.  3.  3.  3.  0.]

我想获得仅包含最后一个非零组的数组。换句话说,除最后一个非零组以外的所有组都应由零替换。(这些组只能是1个元素)。喜欢:

[ 0.  0.  0.  0.  0.  0.  3.  3.  3.  0.]

非固体解决方案

这似乎可以解决问题。反向数组,找到第一个索引,其中元素之间的变化为负。然后用零替换所有后续元素。然后翻转。有点漫长:

>> b = a[::-1]
>> b[np.where(np.ediff1d(b) < 0)[0][0] + 1:] = 0
>> c = b[::-1]
>> print c
[ 0.  0.  0.  0.  0.  0.  3.  3.  3.  0.]

特定情况失败

但是,它在以下情况下不强大,并且在以下情况下失败(因为Where命令返回一个空索引列表):

>> a = np.zeros(10)
>> a[0:4] = 2
>> print a
[ 2.  2.  2.  2.  0.  0.  0.  0.  0.  0.]
>> b = a[::-1]
>> b[np.where(np.ediff1d(b) < 0)[0][0] + 1:] = 0
>> c = b[::-1]
>> print c
Traceback (most recent call last):
  File "<ipython-input-81-8cba57558ba8>", line 1, in <module>
    runfile('C:/Users/name/test1.py', wdir='C:/Users/name')
  File "C:ProgramDataAnaconda2libsite-packagesspyderutilssitesitecustomize.py", line 866, in runfile
    execfile(filename, namespace)
  File "C:ProgramDataAnaconda2libsite-packagesspyderutilssitesitecustomize.py", line 87, in execfile
    exec(compile(scripttext, filename, 'exec'), glob, loc)
  File "C:/Users/name/test1.py", line 21, in <module>
    b[np.where(np.ediff1d(b) < 0)[0][0] + 1:] = 0
IndexError: index 0 is out of bounds for axis 0 with size 0

fix

所以我需要介绍一个if子句:

>> b = a[::-1]
>> if len(np.where(np.ediff1d(b) < 0)[0]) > 0:
>>     b[np.where(np.ediff1d(b) < 0)[0][0] + 1:] = 0
>> c = b[::-1]
>> print c
[ 2.  2.  2.  2.  0.  0.  0.  0.  0.  0.]

有一种更优雅的方法吗?

更新从Divakar的出色答案和MTRW的问题开始,我想扩展规格。如果输入阵列的非零值对分组内变化的非零数字组为负,则该方法也应起作用。

例如。np.array([1, 0, 0, 4, 5, 4, 5, 0, 0])

这意味着我们检查元素之间有正差或负差的方法,以找到组边界,无法很好地工作。

方法#1

由于我们是优雅的,让我们自己 feed 自己是一个单线 -

a[:(a[1:] > a[:-1]).cumsum().argmax()] = 0

样本运行 -

In [605]: a
Out[605]: array([ 0.,  0.,  2.,  2.,  0.,  0.,  3.,  3.,  3.,  0.])
In [606]: a[:(a[1:] > a[:-1]).cumsum().argmax()] = 0
In [607]: a
Out[607]: array([ 0.,  0.,  0.,  0.,  0.,  0.,  3.,  3.,  3.,  0.])

方法#2

上面的方法假设最后一个组号大于 0 s。如果不是这种情况,并且对于非零组组可能具有不同数字的情况

mask = a != 0
a[:(mask[1:] > mask[:-1]).cumsum().argmax()] = 0

样本运行 -

In [667]: a
Out[667]: array([-1,  0,  0, -4, -5,  4, -5,  0,  0])
In [668]: mask = a != 0
In [669]: a[:(mask[1:] > mask[:-1]).cumsum().argmax()] = 0
In [670]: a
Out[670]: array([ 0,  0,  0, -4, -5,  4, -5,  0,  0])

最新更新