如何使用scikit学习适合多维输出



我正在尝试在训练数据中拟合OneVsAll Classification输出,输出行加起来为1。

一种可能的方法是读取所有行,找出哪一列的值最高,并为训练准备数据。

例如:y = [[0.2,0.8,0],[0,1,0],[0,0.3,0.7]]可以简化为y = [b,b,c],将a,b,c分别视为列0,1,2的对应类。

scikit learn中是否有一个功能可以帮助实现这样的转换?

此代码可以执行您想要的操作:

import numpy as np
import string
y = np.array([[0.2,0.8,0],[0,1,0],[0,0.3,0.7]])
def transform(y,labels):
  f = np.vectorize(lambda i : string.letters[i])
  y = f(y.argmax(axis=1)) 
  return y
y = transform(y,'abc') 

编辑:使用alko的注释,我让它更通用,让用户为转换函数提供标签。

相关内容

  • 没有找到相关文章

最新更新