如何根据条件有效地将函数应用于数组中的值



我有一个数组arorg这样:

import numpy as np
arorg = np.array([[-1., 2., -4.], [0.5, -1.5, 3]])

和另一个数组values,如下:

values = np.array([1., 0., 2.])

values具有与arorg相同的条目。

现在,我想将功能应用于条目或arorg,具体取决于它们是正还是负面的:

def neg_fun(val1, val2):
    return val1 / (val1 + abs(val2))
def pos_fun(val1, val2):
    return 1. / ((val1 / val2) + 1.)

因此, val2arorgval1中的(绝对(值 - 这是棘手的部分 - 来自 values;如果我将pos_funneg_fun应用于arorg中的i,则val1应为values[i]

我目前实现了以下内容:

ar = arorg.copy()
for (x, y) in zip(*np.where(ar > 0)):
    ar.itemset((x, y), pos_fun(values[y], ar.item(x, y)))
for (x, y) in zip(*np.where(ar < 0)):
    ar.itemset((x, y), neg_fun(values[y], ar.item(x, y)))

给我所需的输出:

array([[ 0.5       ,  1.        ,  0.33333333],
       [ 0.33333333,  0.        ,  0.6       ]])

我经常进行这些计算,我想知道是否有更有效的方法可以这样做。像

np.where(arorg > 0, pos_fun(xxxx), arorg)

很棒,但我不知道如何正确地通过参数(xxx(。有任何建议吗?

,如问题所示,这是使用np.where

首先,我们使用函数实现的直接翻译来生成正案例和负面案例的值/数组。然后,使用一个正值的掩码,我们将使用np.where在这两个数组之间进行选择。

因此,实现将符合这些行 -

# Get positive and negative values for all elements
val1 = values
val2 = arorg
neg_vals = val1 / (val1 + np.abs(val2))
pos_vals = 1. / ((val1 / val2) + 1.)
# Get a positive mask and choose between positive and negative values 
pos_mask = arorg > 0
out = np.where(pos_mask, pos_vals, neg_vals)

您不需要将功能应用于数组的拉链元素,您可以通过简单的数组操作和切片来完成相同的事情。

首先,获取正面和负计算,保存为数组。然后创建一个零的返回数组(作为默认值(,并使用posneg的布尔片填充它:

import numpy as np
arorg = np.array([[-1., 2., -4.], [0.5, -1.5, 3]])
values = np.array([1., 0., 2.])
pos = 1. / ((values / arorg) + 1)
neg = values / (values + np.abs(arorg))
ret = np.zeros_like(arorg)
ret[arorg>0] = pos[arorg>0]
ret[arorg<=0] = neg[arorg<=0]
ret
# returns:
array([[ 0.5       ,  1.        ,  0.33333333],
       [ 0.33333333,  0.        ,  0.6       ]])
import numpy as np
arorg = np.array([[-1., 2., -4.], [0.5, -1.5, 3]])
values = np.array([1., 0., 2.])
p = 1.0/(values/arorg+1)
n = values/(values+abs(arorg))
#using np.place to extract negative values and put them to p
np.place(p,arorg<0,n[arorg<0])
print(p)
[[ 0.5         1.          0.33333333]
 [ 0.33333333  0.          0.6       ]]

最新更新