选取具有1D索引数组的3D numpy数组



假设有2个Numpy数组:形状为(100,10,2(的3d_array,带有形状(100(的1d_indices

什么是Numpy方法/等效方法:

result = []
for i,j in zip(range(len(3d_array)),1d_indices):
result.append(3d_array[i,j])

应该返回结果。shape(100.2(

我最接近的是在Numpy上使用花式索引:

result = 3d_array[np.arange(len(3d_array)), 1d_indices]

您的代码片段应该等效于3d_array[:, 1d_indices].reshape(-1,2),例如:

a = np.arange(100*10*2).reshape(100,10,2) # 3d array
b = np.random.randint(0, 10, 100) # 1d indices
def fun(a,b):
result = []
for i in range(len(a)):
for j in b:
result.append(a[i,j])
return np.array(result)
assert (a[:, b].reshape(-1, 2) == fun(a, b)).all()

最新更新