删除3D numpy数组中包含一定值的所有行



我有一个3D numpy数组

position = 
[[[ 12.86420681 825.87040876   1.           8.           0.        ]
[753.26000819 280.1334669    2.           8.           1.        ]
[ 51.6851021  330.65314794   3.           8.           0.        ]
[661.07157006  78.15962738   4.           8.           1.        ]
[878.59383346 550.5236096    5.           8.           1.        ]
[774.49249941 942.74557677   6.           8.           1.        ]
[301.20619756 206.50737851   7.           8.           1.        ]
[240.50228642  91.21979947   8.           8.           0.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]]

在数组的第一页上,它有一些随机值,所有其他页面都是零,我只对第一页的第4列感兴趣,它有1或0。我想删除所有页面的所有行,如果第一页上的行值为&;0&;

我试图通过给出以下代码来解决这个问题:

positionNew = np.delete(position, np.where(position[0] == 0), axis=1)

但是我得到了这个作为输出:

positionNew =
[[[753.26000819 280.1334669    2.           8.           1.        ]
[661.07157006  78.15962738   4.           8.           1.        ]
[774.49249941 942.74557677   6.           8.           1.        ]
[301.20619756 206.50737851   7.           8.           1.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]]

但是我缺少了其中一行,它的值为"1",也就是这个:

[878.59383346 550.5236096    5.           8.           1.        ]

我哪里做错了?

我不能完全重建您的数据(请下次分享代码!),但这是一个类似的例子:

import numpy as np
d1 = [[1,2,3,0],
[4,5,6,1],
[7,8,9,0],
[2,4,6,1]]
d2 = [[0,0,0,0],
[0,0,0,0],
[0,0,0,0],
[0,0,0,0],]
d3 = [[0,0,0,0],
[0,0,0,0],
[0,0,0,0],
[0,0,0,0],]
a = np.array([d1, d2, d3])

你可以使用布尔索引来选择你想要的数据:

# for d1 (0), find where all rows (:) have a 1 in the last column (-1)
mask = a[0, :, -1] == 1
print(mask)
# for all of d1, d2, d3 (:), index the rows with the mask
a = a[:, mask]
print(a)

输出:

[False  True False  True]
[[[4 5 6 1]
[2 4 6 1]]
[[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]]]

最新更新