检查值是否高于阈值(如果是),请替换为



我的目标是列出我的:

predictions = [0, 0.2, 0.9, 0.7]

如果

高于 0.5,则如果不是 1,则应变为 0

我试过了:

predictions = np.where(predictions>=0.5,1, 0).tolist()

但是当取第一个元素时,它是:

[0]

不仅0

我想做的事的最佳方式是什么?

只需使用 np.array.round 方法:

>>> predictions = np.array([0, 0.2, 0.9, 0.7])
>>> predictions.round().tolist()
[0, 0, 1, 1]
>>> 

如果你真的需要一个列表理解,请:

>>> predictions = [0, 0.2, 0.9, 0.7]
>>> [int(i >= 0.5) for i in predictions]
[0, 0, 1, 1]

您可以使用列表理解:

[0 if i>=0.5 else 1 for i in predictions]

注意:您说如果原始值高于 0.5,则希望条目为零。你确定吗?或者如果低于 0.5,它应该为零?

如果预测是一个列表,那么让它成为一个 numpy 数组,其中工作:

import numpy as np
predictions = [0.0, 0.2, 0.9, 0.7] 
newList = np.where(np.array(predictions) >= 0.5,1, 0).tolist()
print(newList)
print(newList[0])

输出:

[0, 0, 1, 1]
0

您正在尝试的内容会导致此错误:TypeError: '>=' not supported between instances of 'list' and 'float'

此错误是因为您将列表传递给 np.where 函数,该函数采用np.array

使用 np.array(predictions) 将列表转换为 numpy 数组可以解决此问题。

因此,总而言之,将您的行更改为此行,您将获得所需的输出。

predictions = np.where(np.array(predictions) >= 0.5, 1, 0).tolist()

假设你有预测是一个 numpy 数组

threshold = 0.5 
prediction=np.array([0,0.2,0.9,0.7])
prediction[prediction<=threshold]=0
prediction[prediction>threshold]=1

注意 - 根据需要更改>或>=符号

最新更新