花哨的索引 numpy ndarray



>假设我有一个形状为a的数组:

import numpy as np
n = 10
d = 5
a = np.zeros(shape = np.repeat(n,d))

并且我想获得对应于沿维度:的索引(0,...,:,...,0)的值,从而产生一个(n,d)形数组bb[i,j] = a[0,...,0,i,0,...,0]i位于第j维。

如何从a中提取b

获取矢量化解决方案的平展索引和索引 -

n = len(a)
d = a.ndim
idxs = np.multiply.outer(n**np.arange(d), np.arange(n))
out = a.flat[idxs]

最简单的方法是做一个for循环:

# get the first slice of `a` along given dimension `j`
def get_slice(a,j):
idx = [0]*len(a.shape)
idx[j] = slice(None)
return a[tuple(idx)]
out = np.stack([get_slice(a,j) for j in range(len(a.shape))])

out.shape(10,5)