如何在Python中基于2D数组对NumPy 3D数组进行索引



假设我有一个形状为(66,5(的NumPy数组A和形状为(100,66,5(的B

A的元素将对B的第一维度(axis=0(进行索引,其中值从0到99(即B的第一维度为100(。

A = 
array([[   1,    0,    0,    1,    0],
[   0,    2,    0,    2,    4],
[   1,    7,    0,    5,    5],
[   2,    1,    0,    1,    7],
[   0,    7,    0,    1,    4],
[   0,    0,    3,    6,    0]
....                         ]])

例如,A[4,1]将采用B的第一维度的索引7、B的第二维度的索引4和第三维度B的索引1。

我想要的是生成形状为(66,5(的阵列C,其中它包含基于A中的元素选择的B中的元素。

您可以使用np.take_along_axis来做到这一点:

import numpy as np
np.random.seed(0)
a = np.random.randint(100, size=(66, 5))
b = np.random.random(size=(100, 66, 5))
c = np.take_along_axis(b, a[np.newaxis], axis=0)[0]
# Test some element
print(c[25, 3] == b[a[25, 3], 25, 3])
# True

如果我理解正确,您正在寻找B第一维度的高级索引。您可以使用np.indices创建B的其他两个维度所需的索引,并使用高级索引:

idx = np.indices(A.shape)
C = B[A,idx[0],idx[1]]

示例:

B = np.random.rand(10,20,30)
A = np.array([[   1,    0,    0,    1,    0],
[   0,    2,    0,    2,    4],
[   1,    7,    0,    5,    5],
[   2,    1,    0,    1,    7],
[   0,    7,    0,    1,    4],
[   0,    0,    3,    6,    0]])
print(C[4,1]==B[7,4,1])
#True

使用以下(使用NumPy库的函数(:

print(A)
# array([[2, 0],
#        [1, 1],
#        [2, 0]])
print(B)
# array([[[ 5,  7],
#         [ 0,  0],
#         [ 0,  0]],
#        [[ 1,  8],
#         [ 1,  9],
#         [10,  1]],
#        [[12, 22],
#         [ 2,  2],
#         [ 2,  2]]])
temp = A.reshape(-1) + np.cumsum(np.ones([A.reshape(-1).shape[0]])*B.shape[0], dtype = 'int') - 3
C = B.swapaxes(0, 1).swapaxes(2, 1).reshape(-1)[temp].reshape(A.shape)
print(C)
# array([[12,  7],
#        [ 1,  9],
#        [ 2,  0]])

最新更新