我想以矩阵操作的方式而不是使用 for 循环来实现以下代码。
a = np.random.randint(0, 7, (4,3))
b = np.random.randint(0, 6, (4,3,2))
c = None
for idx in xrange(a.shape[0]):
max_idx = np.argmax(a[idx])
ex_b = b[idx, max_idx].reshape(1, -1)
if c is None:
c = ex_b
else:
c = np.concatenate((c, ex_b), axis=0)
基本上,我想首先获得 a 的第二个维度中最大值的索引。然后我想根据这些指数提取 b 中相应的三维值。
例如:
a:
array([[5, 4, 1],
[3, 1, 3],
[4, 1, 2],
[0, 0, 5]])
b:
array([[[1, 3], [1, 4], [5, 0]],
[[2, 4], [2, 2], [1, 2]],
[[2, 1], [1, 2], [4, 5]],
[[4, 0], [5, 5], [0, 2]]])
那么np.argmax(a, axis=1)
会给array([0, 0, 0, 2])
所以c[0] = b[0][0], c[1]=b[1]b[0], c[2]=b[2][0], c[3]=b[3][2]
我认为这个 for 循环会减慢运行速度,有没有更优雅的方法可以更快地实现这一目标?
你可以像这样使用花哨的索引:
b[np.arange(4), np.argmax(a, axis=-1)]
# array([[1, 3],
# [2, 4],
# [2, 1],
# [0, 2]])