Vectorized way to contract a NumPy array using advanced indexing



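I have a NumPy array of dimensions (d1,d2,d3,d4), for instance A = np.arange(120).reshape((2,3,4,5)). I would like to contract it to obtain an array B of dimensions (d1,d2,d4). The d3-indices of the parts to pick are collected in an index array Idx of dimensions (d1,d2). For each pair (x1,x2) of indices along (d1,d2), Idx provides the index x3 for which B should keep the whole corresponding d4-line of A, for example Idx = rng.integers(4, size=(2,3)).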

In short, for all (x1,x2), I want B[x1,x2,:] = A[x1,x2,Idx[x1,x2],:].

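Is there an efficient, vectorized way of doing this without loops? I am aware this is similar to "Easy way to do nd-array contraction using advanced indexing in Python", but I have trouble extending that solution to higher-dimensional arrays.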

MWE

import numpy as np

rng = np.random.default_rng()
A = np.arange(120).reshape((2,3,4,5))
Idx = rng.integers(4, size=(2,3))
# correct result:
B = np.zeros((2,3,5))
for i in range(2):
    for j in range(3):
        B[i,j,:] = A[i,j,Idx[i,j],:]
# what I would like, which doesn't work
# (it gives shape (2, 3, 2, 3, 5) instead of (2, 3, 5)):
B = A[:,:,Idx[:,:],:]
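To see why the last attempt fails: with a single advanced index in the middle, NumPy combines every (x1, x2) slice with every entry of Idx, so the result has shape (d1, d2, d1, d2, d4). The wanted rows sit on its "diagonal", where both index pairs coincide. The sketch below is for illustration only (the names `wrong`, `rows`, `cols` and `diag` are ours, not from the question); it confirms that the diagonal matches the loop result, and it also shows that the attempt does d1·d2 times more work than needed.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility
A = np.arange(120).reshape((2, 3, 4, 5))
Idx = rng.integers(4, size=(2, 3))

# Reference result from the explicit double loop
B = np.zeros((2, 3, 5), dtype=A.dtype)
for i in range(2):
    for j in range(3):
        B[i, j, :] = A[i, j, Idx[i, j], :]

# The attempted expression indexes axis 2 with the whole (2, 3) array Idx,
# so every (x1, x2) pair is combined with every entry of Idx
wrong = A[:, :, Idx, :]
print(wrong.shape)  # (2, 3, 2, 3, 5)

# The wanted rows are wrong[x1, x2, x1, x2, :]
rows = np.arange(2)[:, None]  # shape (2, 1)
cols = np.arange(3)           # shape (3,)
diag = wrong[rows, cols, rows, cols]  # broadcasts to (2, 3), plus the last axis
print(np.array_equal(diag, B))  # True
```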
One way is np.squeeze(np.take_along_axis(A, Idx[:,:,None,None], axis=2), axis=2).

For example:

In [49]: A = np.arange(120).reshape(2, 3, 4, 5)
In [50]: rng = np.random.default_rng(0xeeeeeeeeeee)
In [51]: Idx = rng.integers(4, size=(2,3))
In [52]: Idx
Out[52]: 
array([[2, 0, 1],
       [0, 2, 1]])
In [53]: C = np.squeeze(np.take_along_axis(A, Idx[:,:,None,None], axis=2), axis=2)
In [54]: C
Out[54]: 
array([[[ 10,  11,  12,  13,  14],
        [ 20,  21,  22,  23,  24],
        [ 45,  46,  47,  48,  49]],

       [[ 60,  61,  62,  63,  64],
        [ 90,  91,  92,  93,  94],
        [105, 106, 107, 108, 109]]])
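`np.take_along_axis` requires the index array to have the same number of dimensions as `A`, which is why `Idx` is expanded to `Idx[:,:,None,None]`, shape (d1, d2, 1, 1). Its length-1 last axis broadcasts against d4, so the intermediate result keeps a length-1 axis 2, which `np.squeeze(..., axis=2)` then drops. A short sketch of the shapes involved, reusing the `Idx` values printed above:

```python
import numpy as np

A = np.arange(120).reshape(2, 3, 4, 5)
Idx = np.array([[2, 0, 1],
                [0, 2, 1]])  # the Idx values from the session above

ind = Idx[:, :, None, None]       # shape (2, 3, 1, 1): same ndim as A
picked = np.take_along_axis(A, ind, axis=2)
print(ind.shape, picked.shape)    # (2, 3, 1, 1) (2, 3, 1, 5)

C = np.squeeze(picked, axis=2)    # drop the length-1 axis 2 -> (2, 3, 5)
print(C[1, 2])                    # [105 106 107 108 109]
```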

Checking against the known correct result:

In [55]: # correct result:
...: B = np.zeros((2,3,5))
...: for i in range(2):
...:     for j in range(3):
...:         B[i,j,:] = A[i,j,Idx[i,j],:]
...: 
In [56]: B
Out[56]: 
array([[[ 10.,  11.,  12.,  13.,  14.],
        [ 20.,  21.,  22.,  23.,  24.],
        [ 45.,  46.,  47.,  48.,  49.]],

       [[ 60.,  61.,  62.,  63.,  64.],
        [ 90.,  91.,  92.,  93.,  94.],
        [105., 106., 107., 108., 109.]]])
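Rather than comparing the printouts by eye, the two results can be compared programmatically. Note that the loop result is float (because `np.zeros` defaults to float64) while the vectorized result keeps `A`'s integer dtype; `np.array_equal` compares shapes and values, not dtypes, so it is a suitable check here. A sketch with an arbitrarily seeded `Idx`:

```python
import numpy as np

A = np.arange(120).reshape(2, 3, 4, 5)
rng = np.random.default_rng(12345)  # arbitrary seed, for illustration
Idx = rng.integers(4, size=(2, 3))

# Vectorized result (same integer dtype as A)
C = np.squeeze(np.take_along_axis(A, Idx[:, :, None, None], axis=2), axis=2)

# Loop reference (float64, since np.zeros defaults to float)
B = np.zeros((2, 3, 5))
for i in range(2):
    for j in range(3):
        B[i, j, :] = A[i, j, Idx[i, j], :]

print(C.dtype == A.dtype, B.dtype)  # True float64
print(np.array_equal(B, C))         # True: values match despite the dtypes
```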

Timings of the three alternatives:

In [91]: %%timeit
...: B = np.zeros((2,3,5),A.dtype)
...: for i in range(2):
...:     for j in range(3):
...:         B[i,j,:] = A[i,j,Idx[i,j],:]
...: 
11 µs ± 48.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [92]: timeit A[np.arange(2)[:,None],np.arange(3),Idx]
8.58 µs ± 44 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [94]: timeit np.squeeze(np.take_along_axis(A, Idx[:,:,None,None], axis=2), axis=2)
29.4 µs ± 448 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
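The fastest variant here, `A[np.arange(2)[:,None],np.arange(3),Idx]`, works because the three index arrays broadcast together to Idx's shape (d1, d2), and the untouched last axis is appended to the result. Since the question asks how to extend this to higher-dimensional arrays, one option (a sketch; the helper name `contract` is hypothetical, and `np.indices(..., sparse=True)` is a substitute for the hand-written `arange` terms) builds the broadcasting ranges automatically for any number of leading axes:

```python
import numpy as np

def contract(A, Idx):
    """Hypothetical helper: B[x..., :] = A[x..., Idx[x...], :].

    A's leading axes must match Idx.shape; the next axis is the one
    indexed by Idx, and any remaining trailing axes are kept whole.
    """
    # Open-mesh ranges, e.g. shapes (d1, 1) and (1, d2) for a 2-D Idx;
    # together with Idx they broadcast to Idx.shape.
    grids = np.indices(Idx.shape, sparse=True)
    return A[tuple(grids) + (Idx,)]

rng = np.random.default_rng(0)

# 4-D case from the question: (d1, d2, d3, d4) -> (d1, d2, d4)
A = np.arange(120).reshape(2, 3, 4, 5)
Idx = rng.integers(4, size=(2, 3))
B = contract(A, Idx)
print(B.shape)  # (2, 3, 5)

# 5-D case with three leading axes: (2, 3, 4, 6, 5) -> (2, 3, 4, 5)
A5 = np.arange(2 * 3 * 4 * 6 * 5).reshape(2, 3, 4, 6, 5)
Idx5 = rng.integers(6, size=(2, 3, 4))
print(contract(A5, Idx5).shape)  # (2, 3, 4, 5)
```

`np.ogrid` would give the same open mesh; `np.indices` is used because it takes the shape directly.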

Relative timings may differ for larger arrays, but this is a good size for testing correctness.
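To compare the alternatives on larger arrays, a benchmark along these lines could be used. The sizes below are arbitrary assumptions, and no timings are claimed here, since they depend on the machine and the NumPy version. The function names `loop`, `fancy` and `take` are ours.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
d1, d2, d3, d4 = 100, 100, 20, 30  # arbitrary "larger" sizes (about 48 MB of float64)
A = rng.random((d1, d2, d3, d4))
Idx = rng.integers(d3, size=(d1, d2))

def loop():
    B = np.zeros((d1, d2, d4), A.dtype)
    for i in range(d1):
        for j in range(d2):
            B[i, j, :] = A[i, j, Idx[i, j], :]
    return B

def fancy():
    return A[np.arange(d1)[:, None], np.arange(d2), Idx]

def take():
    return np.squeeze(np.take_along_axis(A, Idx[:, :, None, None], axis=2), axis=2)

# The three variants must agree before their timings mean anything
assert np.array_equal(loop(), fancy()) and np.array_equal(fancy(), take())

for f in (loop, fancy, take):
    t = min(timeit.repeat(f, number=3, repeat=3)) / 3  # best average per call
    print(f"{f.__name__}: {t * 1e3:.2f} ms per call")
```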
