numpy中axis选项背后的算法是什么



我无法理解numpy中的axis如何适用于任何通用n-D数组

例如,无论array的形状如何,np.mean(array, axis=1)都将给出期望的数组。

如果我想在不使用numpy的情况下编写代码,我可以为k-D数组制作平均函数,例如mean_1dmean_2d。。。带轴选项。然而,我无法想象如何制作一个适用于任何k值的输入数组的函数。到目前为止,我还找不到一个解释numpy轴是如何工作的。有人能帮我了解幕后发生了什么吗?

添加:我最初的目标是制作一个代码,使用numba进行具有异常值拒绝的快速阵列组合(即,n-D阵列列表的西格玛截断中值组合,并获得(n-1(-D阵列(。

如果数组的大小允许,您可以递归地查看它。我刚刚测试了下面的代码,它似乎可以很好地用于轴3上的sum函数。这很简单,但它应该传达这样的想法:

import numpy as np
a = np.ones((3,4,5,6,7))
# first get the shape of the result
def shape_res(sh1, axis) :
sh2 = [sh1[0]]
if len(sh1)>0 :
for i in range(len(sh1)) :
if i > 0 and i != axis :
sh2.append(sh1[i])
elif i == axis :
# shrink the summing axis
sh2.append(1)
return sh2

def sum_axis(a, axis=0) : 
cur_sum = 0
sh1 = np.shape(a)
sh2 = shape_res(sh1, axis)
res = np.zeros(sh2)
print(sh1, sh2)
def rec_fun (a, res, cur_axis) :
for i in range(len(a)) :
if axis > cur_axis :
# dig in the array to reach the right axis
next_axis = cur_axis + 1
rec_fun(a[i], res[i], next_axis)
elif axis == cur_axis :
# sum along the given axis
res[0] += a[i]
rec_fun(a, res, cur_axis)
return res
result = sum_axis(a, axis=3)
print(result)

最新更新