基于存储在另一个数组或列表中的索引拆分numpy多维数组



我有一个形状为(12,2,3,3(的numpy多维数组

import numpy as np
arr = np.arange(12*2*3*3).reshape((12,2,3,3))

我需要根据第二维度来选择那些元素,在第二维度中,数据块存储在另一个列表中

indices = [0,1,0,0,1,1,0,1,1,0,1,1]

在一个阵列中,其余的在另一阵列中。在任何一种情况下的输出都应该是另一个形状为(12,3,3(的阵列

arr2 = np.empty((arr.shape[0],*arr.shape[-2:]))

我可以使用for loop

for i, ii in enumerate(indices):
arr2[i] = arr[i, indices[ii],...]

然而,我正在寻找一句俏皮话。

当我尝试使用列表作为索引进行索引时

test = arr[:,indices,...]

我得到了形状为(12,12,3,3(的test,而不是(12,3,3(。你能帮我吗?

您可以使用np.arange对第一个维度进行索引:

test = arr[np.arange(arr.shape[0]),indices,...]

或者只是pythonrange函数:

test = arr[range(arr.shape[0]),indices,...]

最新更新