我有一个numpy数组r
,我需要计算一个标量函数,我们说np.sqrt(1-x**2)
对这个数组的每个元素x
。然而,我想返回函数的值为零,只要x>1
,和x
上的函数的值,否则。最后的结果应该是一个numpy的标量数组。
我怎样才能用最python的方式来写呢?
你可以用numpy.where(condition,if condition holds,otherwise)
这样np.where(x>1,0,np.sqrt(1-x**2))
就是答案
与普通的求值:
In [19]: f=lambda x:np.sqrt(1-x**2)
In [20]: f(np.linspace(0,2,10))
C:UserspaulAppDataLocalTempipykernel_16041368662409.py:1: RuntimeWarning: invalid value encountered in sqrt
f=lambda x:np.sqrt(1-x**2)
Out[20]:
array([1. , 0.97499604, 0.89580642, 0.74535599, 0.45812285,
nan, nan, nan, nan, nan])
生成nan
和x>1
的警告。其他答案建议的np.where
可以将nan
更改为0
,但您仍然会收到警告。警告可以消音。但另一种选择是使用np.sqrt
(和其他ufunc
)的where
参数:
In [21]: f=lambda x:np.sqrt(1-x**2, where=x<=1, out=np.zeros_like(x))
In [22]: f(np.linspace(0,2,10))
Out[22]:
array([1. , 0.97499604, 0.89580642, 0.74535599, 0.45812285,
0. , 0. , 0. , 0. , 0. ])
y = np.where(
r>1, #if r[i]>1:
0, #y[i]=0
np.sqrt(1-r**2) #else: y[i] = (1-r[i]**2)
)