numpy argmax with groupby



假设我有标签将数据分成几组。现在我想找到每个簇的原始数组的最大w.r.t的索引。例如:

import numpy as np
labels = np.array([1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0])
y = np.random.randint(0, 9, len(labels)) #array([6, 7, 5, 4, 2, 8, 4, 4, 5, 6, 4])

我想要[1,5],因为对于集群1,索引1处的最大值是7,对于集群0,索引5处的最大值是8。有没有可能不使用for循环?

我的朴素解供参考:

out = []
for i in [0, 1]:
temp = y.copy()
temp[labels == i] = -1
out.append(np.argmax(temp))

我相信这应该能行:

labels = np.array([1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0])
y = np.array([6, 7, 5, 4, 2, 8, 4, 4, 5, 6, 4])
arr = np.unique(labels)
result = np.ma.masked_where(labels[:,None] == arr, np.tile(y,(arr.shape[0],1)).T).argmax(axis=0)

主要思想是我们通过添加一个新维度来强制1D数组广播。然后,我们还必须为第二个数组创建相同的维度,但这使我们可以在不循环的情况下比较任意多的值。

最新更新