沿2个轴的numpy`take'



i具有数据的3D数组a和索引的2D数组b。我需要使用b的索引沿3轴沿第三轴进行a的子阵列。我可以用take这样做:

a = np.arange(24).reshape((2,3,4))
b = np.array([0,2,1,3]).reshape((2,2))
np.array([np.take(a_,b_,axis=1) for (a_,b_) in zip(a,b)])

我可以使用一些花哨的索引而无需列表理解吗?我担心效率,因此,如果在这种情况下,花哨的索引效率不高,我想知道。

编辑我尝试过的第一件事是a[[0,1],:,b],但它没有给出子阵列,我需要

In [317]: a
Out[317]: 
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
In [318]: a = np.arange(24).reshape((2,3,4))
     ...: b = np.array([0,2,1,3]).reshape((2,2))
     ...: np.array([np.take(a_,b_,axis=1) for (a_,b_) in zip(a,b)])
     ...: 
Out[318]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],
       [[13, 15],
        [17, 19],
        [21, 23]]])

所以您想要0&来自第一个块的2列,1&3从第二个。

制作一个与b相匹配的c,并体现此观察

In [319]: c=np.array([[0,0],[1,1]])
In [320]: c
Out[320]: 
array([[0, 0],
       [1, 1]])
In [321]: b
Out[321]: 
array([[0, 2],
       [1, 3]])
In [322]: a[c,:,b]
Out[322]: 
array([[[ 0,  4,  8],
        [ 2,  6, 10]],
       [[13, 17, 21],
        [15, 19, 23]]])

那是正确的数字,但不是正确的形状。

可以使用列向量而不是c

In [323]: a[np.arange(2)[:,None],:,b]  # or a[[[0],[1]],:,b]
Out[323]: 
array([[[ 0,  4,  8],
        [ 2,  6, 10]],
       [[13, 17, 21],
        [15, 19, 23]]])

至于形状,我们可以转置最后两个轴

In [324]: a[np.arange(2)[:,None],:,b].transpose(0,2,1)
Out[324]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],
       [[13, 15],
        [17, 19],
        [21, 23]]])

需要此转置,因为我们在两个索引阵列之间有一个切片,即基本和高级索引的混合。它已记录在记录中,但从来没有那么经常令人困惑。它最后放置了切片尺寸(3(,我们必须将其转回。

漂亮的小索引难题!

此高级/基本转置的最新问题和解释:

索引数字多维阵列取决于切片方法

这是我的第一次尝试。我会看看是否可以做得更好。

#using numpy broadcasting.
np.r_[a[0][:,b[0]],a[1][:,b[1]]].reshape(2,3,2)
Out[300]: In [301]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],
       [[13, 15],
        [17, 19],
        [21, 23]]])

第二次尝试:

#convert both a and b to a 2d array and then slice all rows and only columns determined by b.
a.reshape(6,4)[np.arange(6)[:,None],b.repeat(3,0)].reshape(2,3,2)
Out[429]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],
       [[13, 15],
        [17, 19],
        [21, 23]]])

最新更新