Numpy抱怨模棱两可的数组:值错误:的真值



我在Python 3中有一个最小的代码,它使用numpy和函数apply_along_axis。我不明白我遇到此错误的原因:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

lambda内提供直接公式是有效的。一旦我使用另一个函数,我就会收到此错误。我应该归还其他东西吗?


最小代码:

import numpy as np
def logn(x, b):
return np.log(x)/np.log(b)
def h(x, b):
if x == 0:
return 0
else:
return -x*logn(x, b)
p = np.array([0.00000000e+00, 9.99997956e-01, 2.04440466e-06])
print(np.apply_along_axis(lambda _e: h(_e, 3), -1, p))

看看apply_along_axis传递给函数的内容:

In [99]: def foo(x): 
...:     print(x) 
...:     return x 
...:                                                                                 
In [100]: np.apply_along_axis(foo, -1, p)                                                
[0.00000000e+00 9.99997956e-01 2.04440466e-06]
Out[100]: array([0.00000000e+00, 9.99997956e-01, 2.04440466e-06])

在一维数组的情况下,它会立即传递整个数组。 它不会在该维度上进行迭代。 这就是apply_along_axis的全部目的 - 将一维数组传递给您的函数。

从其他SOapply_along_axis来看不是很有用,并且经常出现问题。它并不比更明确的迭代快。 对于3d(或更高(,它可以使迭代(在"其他"两个轴上(更简单(但同样不是更快(。

对于 1dp,这更简单:

In [102]: [h(_e,3) for _e in p]                                                          
Out[102]: [0, 1.8605270777946112e-06, 2.4378506521338855e-05]

非迭代方法是使用布尔掩码来选择在计算中使用哪些p。 这样,您就不必使用标量if表达式:

In [106]: mask = p!=0                                                                    
In [107]: mask                                                                           
Out[107]: array([False,  True,  True])
In [108]: p1 = p[mask]                                                                   
In [109]: res = np.zeros(p.shape)                                                        
In [110]: res[mask] = -p1*logn(p1,3)                                                     
In [111]: res                                                                            
Out[111]: array([0.00000000e+00, 1.86052708e-06, 2.43785065e-05])

np.log这样的ufunc采用一个where参数,该参数可用于绕过错误的输入值:

In [114]: -p * np.log(p, where=(p!=0), out=np.zeros(p.shape))/np.log(3)                  
Out[114]: array([-0.00000000e+00,  1.86052708e-06,  2.43785065e-05])

最新更新