我有一个多维numpy数组,如下所示:
[
[
[1,2,3,4,5],
[6,7,8,9,10],
[11,12,13,14,15]
],
[
[16,17,18,19,20],
[21,22,23,24,25],
[26,27,28,29,30]
]
]
并希望创建一个函数,将其沿指定的轴切成两半,在大小不均匀的情况下不包括中间元素。所以如果我说my_function(my_ndarray, 0)
,我想得到
[
[
[1,2,3,4,5],
[6,7,8,9,10],
[11,12,13,14,15]
]
]
对于my_function(my_ndarray, 1)
,我想获得
[
[
[1,2,3,4,5]
],
[
[16,17,18,19,20]
]
]
对于my_function(my_ndarray, 2)
,我想获得
[
[
[1,2],
[6,7],
[11,12]
],
[
[16,17],
[21,22],
[26,27]
]
]
我的第一次尝试涉及np.split((方法,但不幸的是,当轴的长度是一个奇数时,它会遇到问题,并且不允许我指定要省略的内容。理论上,如果是这样的话,我可以做一个if语句,并切掉所选轴的最后一片,但我想知道是否有更有效的方法来解决这个问题。
给定一个轴axis
和一个数组a
,我认为你可以进行
def my_function(a, axis):
l = a.shape[axis]//2
return a.take(range(l), axis=axis)
示例:
>>> my_function(a, 0)
array([[[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]]])
>>> my_function(a, 1)
array([[[ 1, 2, 3, 4, 5]],
[[16, 17, 18, 19, 20]]])
>>> my_function(a, 2)
array([[[ 1, 2],
[ 6, 7],
[11, 12]],
[[16, 17],
[21, 22],
[26, 27]]])
关于:
def slice_n(a, n):
slices = [slice(None)]*a.ndim
slices[n] = slice(0, a.shape[n]//2)
return a[tuple(slices)]
slice_n(a, 0)
array([[[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]]])
slice_n(a, 1)
array([[[ 1, 2, 3, 4, 5]],
[[16, 17, 18, 19, 20]]])
slice_n(a, 2)
array([[[ 1, 2],
[ 6, 7],
[11, 12]],
[[16, 17],
[21, 22],
[26, 27]]])
使用的输入(a
(:
array([[[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]],
[[16, 17, 18, 19, 20],
[21, 22, 23, 24, 25],
[26, 27, 28, 29, 30]]])