沿着特定的轴,按另一个数组对numpy数组进行排序



与此答案类似,我有一对3D numpy数组ab,我想根据a的值对b的条目进行排序。与这个答案不同,我只想沿着数组的一个轴进行排序。

我对numpy.argsort()文档的天真解读:

Returns
-------
index_array : ndarray, int
    Array of indices that sort `a` along the specified axis.
    In other words, ``a[index_array]`` yields a sorted `a`.

让我相信我可以用以下代码进行分类:

import numpy
a = numpy.zeros((3, 3, 3))
a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
print "a"
print a
"""
[[[ 1.  1.  1.]
  [ 1.  1.  1.]
  [ 1.  1.  1.]]
 [[ 3.  3.  3.]
  [ 3.  3.  3.]
  [ 3.  3.  3.]]
 [[ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]]]
"""
b = numpy.arange(3*3*3).reshape((3, 3, 3))
print "b"
print b
"""
[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]
 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]
 [[18 19 20]
  [21 22 23]
  [24 25 26]]]
"""
print "a, sorted"
print numpy.sort(a, axis=0)
"""
[[[ 1.  1.  1.]
  [ 1.  1.  1.]
  [ 1.  1.  1.]]
 [[ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]]
 [[ 3.  3.  3.]
  [ 3.  3.  3.]
  [ 3.  3.  3.]]]
"""
##This isnt' working how I'd like
sort_indices = numpy.argsort(a, axis=0)
c = b[sort_indices]
"""
Desired output:
[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]
 [[18 19 20]
  [21 22 23]
  [24 25 26]]
 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]]
"""
print "Desired shape of b[sort_indices]: (3, 3, 3)."
print "Actual shape of b[sort_indices]:"
print c.shape
"""
(3, 3, 3, 3, 3)
"""

做这件事的正确方法是什么?

您仍然需要为其他两个维度提供索引才能正常工作。

>>> a = numpy.zeros((3, 3, 3))
>>> a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
>>> b = numpy.arange(3*3*3).reshape((3, 3, 3))
>>> sort_indices = numpy.argsort(a, axis=0)
>>> static_indices = numpy.indices((3, 3, 3))
>>> b[sort_indices, static_indices[1], static_indices[2]]
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],
       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]],
       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])

numpy.indices计算通过其他两个轴(或n-1个轴,其中n=轴的总数)"展平"时阵列的每个轴的索引。换句话说,这(为长帖道歉):

>>> static_indices
array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
        [[1, 1, 1],
         [1, 1, 1],
         [1, 1, 1]],
        [[2, 2, 2],
         [2, 2, 2],
         [2, 2, 2]]],

       [[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],
        [[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],
        [[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]]],

       [[[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]],
        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]],
        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]]])

这些是每个轴的同一性指数;当用于索引b时,它们重新创建b。

>>> b[static_indices[0], static_indices[1], static_indices[2]]
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],
       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],
       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])

作为numpy.indices的替代方案,您可以使用numpy.ogrid,正如Undepu所建议的那样。由于ogrid生成的对象较小,为了保持一致性,我将创建所有三个轴,但请注意Undepu的注释,以获得只生成两个轴的方法。

>>> static_indices = numpy.ogrid[0:a.shape[0], 0:a.shape[1], 0:a.shape[2]]
>>> a[sort_indices, static_indices[1], static_indices[2]]
array([[[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],
       [[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]],
       [[ 3.,  3.,  3.],
        [ 3.,  3.,  3.],
        [ 3.,  3.,  3.]]])

numpy.take_along_axis()可以做到这一点,而且可能不会占用不必要的额外内存:

# From the original question:
import numpy
a = numpy.zeros((3, 3, 3))
a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
b = numpy.arange(3*3*3).reshape((3, 3, 3))
sort_indices = numpy.argsort(a, axis=0)
# This is not working as expected:
c = b[sort_indices]
# This does what is expected:
c = numpy.take_along_axis(b, sort_indices, axis=0)
print(c)
"""
[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]
 [[18 19 20]
  [21 22 23]
  [24 25 26]]
 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]]
"""

相关内容

  • 没有找到相关文章