Numpy - 索引多维数组的一维



>我有一个形状为(6, 2, 4)的numpy数组:

x = np.array([[[0, 3, 2, 0],
[1, 3, 1, 1]],
[[3, 2, 3, 3],
[0, 3, 2, 0]],
[[1, 0, 3, 1],
[3, 2, 3, 3]],
[[0, 3, 2, 0],
[1, 3, 2, 2]],
[[3, 0, 3, 1],
[1, 0, 1, 1]],
[[1, 3, 1, 1],
[3, 1, 3, 3]]])

我有这样的数组choices

choices = np.array([[1, 1, 1, 1],
[0, 1, 1, 0],
[1, 1, 1, 1],
[1, 0, 0, 0],
[1, 0, 1, 1],
[0, 0, 0, 1]])

如何使用choices数组仅索引大小为 2 的中间维度,并以最有效的方式获取形状(6, 4)的新 numpy 数组?

结果将是这样的:

[[1 3 1 1]
[3 3 2 3]
[3 2 3 3]
[1 3 2 0]
[1 0 1 1]
[1 3 1 3]]

我试图通过x[:, choices, :]做到这一点,但这并没有返回我想要的。我也尝试做x.take(choices, axis=1)但没有运气。

使用np.take_along_axis沿第二个轴进行索引 -

In [16]: np.take_along_axis(x,choices[:,None],axis=1)[:,0]
Out[16]: 
array([[1, 3, 1, 1],
[3, 3, 2, 3],
[3, 2, 3, 3],
[1, 3, 2, 0],
[1, 0, 1, 1],
[1, 3, 1, 3]])

或者使用显式integer-array索引 -

In [22]: m,n = choices.shape
In [23]: x[np.arange(m)[:,None],choices,np.arange(n)]
Out[23]: 
array([[1, 3, 1, 1],
[3, 3, 2, 3],
[3, 2, 3, 3],
[1, 3, 2, 0],
[1, 0, 1, 1],
[1, 3, 1, 3]])

因为我最近遇到了这个问题,发现@divakar的答案很有用,但仍然想要一个通用功能(与 dims 的数量等无关(,这里是:

def take_indices_along_axis(array, choices, choice_axis):
"""
array is N dim
choices are integer of N-1 dim
with valuesbetween 0 and array.shape[choice_axis] - 1
choice_axis is the axis along which you want to take indices
"""
nb_dims = len(array.shape)
list_indices = []
for this_axis, this_axis_size in enumerate(array.shape):
if this_axis == choice_axis:
# means this is the axis along which we want to choose
list_indices.append(choices)
continue
# else, we want arange(this_axis), but reshaped to match the purpose
this_indices = np.arange(this_axis_size)
reshape_target = [1 for _ in range(nb_dims)]
reshape_target[this_axis] = this_axis_size # replace the corresponding axis with the right range
del reshape_target[choice_axis] # remove the choice_axis
list_indices.append(
this_indices.reshape(tuple(reshape_target))
)
tuple_indices = tuple(list_indices)
return array[tuple_indices]
# test it !
array = np.random.random(size=(10, 10, 10, 10))
choices = np.random.randint(10, size=(10, 10, 10))
assert take_indices_along_axis(array, choices, choice_axis=0)[5, 5, 5] == array[choices[5, 5, 5], 5, 5, 5]
assert take_indices_along_axis(array, choices, choice_axis=2)[5, 5, 5] == array[5, 5, choices[5, 5, 5], 5]

相关内容

  • 没有找到相关文章

最新更新