假设我有一个形状为(66,5(的NumPy数组A
和形状为(100,66,5(的B
。
A
的元素将对B
的第一维度(axis=0
(进行索引,其中值从0到99(即B
的第一维度为100(。
A =
array([[ 1, 0, 0, 1, 0],
[ 0, 2, 0, 2, 4],
[ 1, 7, 0, 5, 5],
[ 2, 1, 0, 1, 7],
[ 0, 7, 0, 1, 4],
[ 0, 0, 3, 6, 0]
.... ]])
例如,A[4,1]将采用B
的第一维度的索引7、B
的第二维度的索引4和第三维度B
的索引1。
我想要的是生成形状为(66,5(的阵列C
,其中它包含基于A
中的元素选择的B
中的元素。
您可以使用np.take_along_axis
来做到这一点:
import numpy as np
np.random.seed(0)
a = np.random.randint(100, size=(66, 5))
b = np.random.random(size=(100, 66, 5))
c = np.take_along_axis(b, a[np.newaxis], axis=0)[0]
# Test some element
print(c[25, 3] == b[a[25, 3], 25, 3])
# True
如果我理解正确,您正在寻找B
第一维度的高级索引。您可以使用np.indices
创建B
的其他两个维度所需的索引,并使用高级索引:
idx = np.indices(A.shape)
C = B[A,idx[0],idx[1]]
示例:
B = np.random.rand(10,20,30)
A = np.array([[ 1, 0, 0, 1, 0],
[ 0, 2, 0, 2, 4],
[ 1, 7, 0, 5, 5],
[ 2, 1, 0, 1, 7],
[ 0, 7, 0, 1, 4],
[ 0, 0, 3, 6, 0]])
print(C[4,1]==B[7,4,1])
#True
使用以下(使用NumPy
库的函数(:
print(A)
# array([[2, 0],
# [1, 1],
# [2, 0]])
print(B)
# array([[[ 5, 7],
# [ 0, 0],
# [ 0, 0]],
# [[ 1, 8],
# [ 1, 9],
# [10, 1]],
# [[12, 22],
# [ 2, 2],
# [ 2, 2]]])
temp = A.reshape(-1) + np.cumsum(np.ones([A.reshape(-1).shape[0]])*B.shape[0], dtype = 'int') - 3
C = B.swapaxes(0, 1).swapaxes(2, 1).reshape(-1)[temp].reshape(A.shape)
print(C)
# array([[12, 7],
# [ 1, 9],
# [ 2, 0]])