删除3D numpy数组中包含特定值的所有行

我有一个3D numpy数组

position = 
[[[ 12.86420681 825.87040876   1.           8.           0.        ]
[753.26000819 280.1334669    2.           8.           1.        ]
[ 51.6851021  330.65314794   3.           8.           0.        ]
[661.07157006  78.15962738   4.           8.           1.        ]
[878.59383346 550.5236096    5.           8.           1.        ]
[774.49249941 942.74557677   6.           8.           1.        ]
[301.20619756 206.50737851   7.           8.           1.        ]
[240.50228642  91.21979947   8.           8.           0.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]]



# NOTE(review): position[0] == 0 is an elementwise test, so np.where returns a
# (row_indices, column_indices) tuple here. Passing the whole tuple to
# np.delete appears to treat the column indices as row indices as well, which
# would explain the extra missing row in the output shown below -- confirm.
# Using only np.where(...)[0] would restrict the deletion to matching rows.
positionNew = np.delete(position, np.where(position[0] == 0), axis=1)


positionNew =
[[[753.26000819 280.1334669    2.           8.           1.        ]
[661.07157006  78.15962738   4.           8.           1.        ]
[774.49249941 942.74557677   6.           8.           1.        ]
[301.20619756 206.50737851   7.           8.           1.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]

[[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]
[  0.           0.           0.           0.           0.        ]]]


[878.59383346 550.5236096    5.           8.           1.        ]



import numpy as np
# Three (4, 4) planes stacked into a single (3, 4, 4) array. Only the first
# plane carries data; its last column is the 0/1 flag used to select rows.
# NOTE(review): the original literals were cut off after their first row.
# Rows 1 and 3 of d1 and the all-zero d2/d3 planes are recovered from the
# printed result below; row 2 of d1 is a placeholder with flag 0 -- confirm.
d1 = [
    [1, 2, 3, 0],
    [4, 5, 6, 1],
    [7, 8, 9, 0],
    [2, 4, 6, 1],
]
d2 = [
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
d3 = [
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
a = np.array([d1, d2, d3])


# for d1 (0), find where all rows (:) have a 1 in the last column (-1)
mask = a[0, :, -1] == 1
# for all of d1, d2, d3 (:), index the rows with the mask; the same row
# positions are kept in every plane, so the planes stay aligned
a = a[:, mask]


[False  True False  True]
[[[4 5 6 1]
[2 4 6 1]]
[[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]]]
