相当于Matlab shiftdim()的Python



我目前正在将一些Matlab代码转换为Python,我想知道是否有类似于Matlab的shiftdim(A, n)的函数

B=shiftdim(A,n(将数组A的维度移动n个位置。当n是正整数时,shiftdim将维度向左移动,当n是负整数时,将维度向右移动。例如,如果A是2乘3乘4的阵列,则shiftdim(A,2(返回4乘2乘3的阵列。

如果使用numpy,则可以使用np.moveaxis

来自文档:

>>> x = np.zeros((3, 4, 5))
>>> np.moveaxis(x, 0, -1).shape
(4, 5, 3)
>>> np.moveaxis(x, -1, 0).shape
(5, 3, 4)

numpy.moveaxis(a,source,destination([source]

Parameters
a:        np.ndarray 
The array whose axes should be reordered.
source:   int or sequence of int
Original positions of the axes to move. These must be unique.
destination: int or sequence of int
Destination positions for each of the original axes. 
These must also be unique.

shiftdim的函数比移动轴要复杂一些。

  • 对于输入shiftdim(A, n),如果n为正,则将轴向左移动n(即旋转(,但如果n为负,则将轴线向右移动并附加大小为1的尾部维度
  • 对于输入shiftdim(A),删除任何大小为1的尾部维度
from collections import deque
import numpy as np
def shiftdim(array, n=None):
if n is not None:
if n >= 0:
axes = tuple(range(len(array.shape)))
new_axes = deque(axes)
new_axes.rotate(n)
return np.moveaxis(array, axes, tuple(new_axes))
return np.expand_dims(array, axis=tuple(range(-n)))
else:
idx = 0
for dim in array.shape:
if dim == 1:
idx += 1
else:
break
axes = tuple(range(idx))
# Note that this returns a tuple of 2 results
return np.squeeze(array, axis=axes), len(axes)

与Matlab文档相同的示例

a = np.random.uniform(size=(4, 2, 3, 5))
print(shiftdim(a, 2).shape)      # prints (3, 5, 4, 2)
print(shiftdim(a, -2).shape)     # prints (1, 1, 4, 2, 3, 5)
a = np.random.uniform(size=(1, 1, 3, 2, 4))
b, nshifts = shiftdim(a)
print(nshifts)                   # prints 2
print(b.shape)                   # prints (3, 2, 4)

最新更新