我的目标是列出我的:
predictions = [0, 0.2, 0.9, 0.7]
高于 0.5,则如果不是 1,则应变为 0
我试过了:
predictions = np.where(predictions>=0.5,1, 0).tolist()
但是当取第一个元素时,它是:
[0]
不仅0
我想做的事的最佳方式是什么?
只需使用 np.array.round
方法:
>>> predictions = np.array([0, 0.2, 0.9, 0.7])
>>> predictions.round().tolist()
[0, 0, 1, 1]
>>>
如果你真的需要一个列表理解,请:
>>> predictions = [0, 0.2, 0.9, 0.7]
>>> [int(i >= 0.5) for i in predictions]
[0, 0, 1, 1]
您可以使用列表理解:
[0 if i>=0.5 else 1 for i in predictions]
注意:您说如果原始值高于 0.5,则希望条目为零。你确定吗?或者如果低于 0.5,它应该为零?
如果预测是一个列表,那么让它成为一个 numpy 数组,其中工作:
import numpy as np
predictions = [0.0, 0.2, 0.9, 0.7]
newList = np.where(np.array(predictions) >= 0.5,1, 0).tolist()
print(newList)
print(newList[0])
输出:
[0, 0, 1, 1]
0
您正在尝试的内容会导致此错误:TypeError: '>=' not supported between instances of 'list' and 'float'
此错误是因为您将列表传递给 np.where
函数,该函数采用np.array
。
使用 np.array(predictions)
将列表转换为 numpy 数组可以解决此问题。
因此,总而言之,将您的行更改为此行,您将获得所需的输出。
predictions = np.where(np.array(predictions) >= 0.5, 1, 0).tolist()
假设你有预测是一个 numpy 数组
threshold = 0.5
prediction=np.array([0,0.2,0.9,0.7])
prediction[prediction<=threshold]=0
prediction[prediction>threshold]=1
注意 - 根据需要更改>或>=符号