我有以下代码行:
idxs = [i for i,x in enumerate(labels) if x==lbl]
- labels是int的numpy数组
- lbl是一个int
idxs=索引s.t.标签的相应元素具有值lbl
问题:有较短的一班吗?
谢谢!
您可以使用numpy.where:的单参数形式
idxs = np.where(labels == lbl)[0]
或者,等效地,使用numpy.nonzero:
idxs = np.nonzero(labels == lbl)[0]
或者,为了更好的可读性(谢谢,乔!),
idxs = np.flatnonzero(labels == lbl)
例如,
In [332]: np.random.seed(1)
In [333]: labels = np.random.randint(5, size=10)
In [334]: labels
Out[334]: array([3, 4, 0, 1, 3, 0, 0, 1, 4, 4])
In [335]: [i for i,x in enumerate(labels) if x==lbl]
Out[335]: [3, 7]
In [336]: np.where(labels == lbl)[0]
Out[336]: array([3, 7])
使用np.where
比大型数组的列表理解快得多:
In [339]: labels = np.tile(labels, 1000)
In [340]: labels.shape
Out[340]: (10000,)
In [341]: %timeit np.where(labels == lbl)[0]
10000 loops, best of 3: 45.9 µs per loop
In [342]: %timeit [i for i,x in enumerate(labels) if x==lbl]
100 loops, best of 3: 5.31 ms per loop
In [343]: 5310/45.9
Out[343]: 115.68627450980392
我没有代表来评论答案。。。不过,请记住,在使用numpy.where
时,"lables"必须是一个numpy数组。
代码提升Undebu的答案:
idxs = np.where(np.array(labels) == lbl)[0]
只是想说清楚:正确的答案是Undebu做出的。