我正在寻求通过numpy数组过滤的帮助。我目前有一个numpy数组,其中包含以下信息:
[[x1_1, x1_2, ..., x1_n], [x2_1, x2_2, ..., x2_n], [y1, y2, ..., yn]
ie。数组本质上是一个数据集,其中x1, x2是特征(坐标),y是输出(值)。每个数据点都有相应的x1, x2和y,因此,例如,数据点i对应的信息为x1_i, x2_i和yi。
现在,我想通过过滤y来提取所有的数据点,这意味着我想知道所有y>某个值的数据点。在我的例子中,我想要y> 0的所有情况的信息(仍然具有相同的numpy结构)。我真的不知道如何做到这一点-我一直在玩布尔索引,如d[0:2,y>0]
或d[d[2]>0]
,但还没有得到任何地方。
一个澄清的例子:
给定数据集:
d = [[0.1, 0.2, 0.3], [-0.1,-0.2,-0.3], [1,1,-1]]
我拉所有的点或实例,其中y > 0
,即。d[2] > 0
,它应该返回值:
[[0.1, 0.2],[-0.1,-0.2],[1,1]]
您可以使用:
import numpy as np
d = np.array([[0.1, 0.2, 0.3], [-0.1,-0.2,-0.3], [1,1,-1]])
print (d)
[[ 0.1 0.2 0.3]
[-0.1 -0.2 -0.3]
[ 1. 1. -1. ]]
#select last row by d[-1]
print (d[-1]>0)
[ True True False]
print (d[:,d[-1]>0])
[[ 0.1 0.2]
[-0.1 -0.2]
[ 1. 1. ]]