SkLearn decision tree does not include numerical features after fitting



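I am trying to fit an SkLearn DecisionTree to a data frame with the code below, but I get the error: Length of feature_names, 9 does not match number of features, 8. The decision tree seems to have been fit only on the one-hot-encoded categorical features, not on the numerical feature. How can I include the numerical feature in the decision tree model?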

import pandas as pd
import numpy as np
from sklearn.compose import ColumnTransformer, make_column_transformer
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn import tree
from matplotlib import pyplot as plt
import graphviz
df = pd.DataFrame({'brand'      : ['aaaa', 'asdfasdf', 'sadfds', 'NaN'],
                   'category'   : ['asdf', 'asfa', 'asdfas', 'as'],
                   'num1'       : [1, 1, 0, 0],
                   'target'     : [1, 0, 0, 1]})
df
dtarget = df['target']
dfeatures = df.drop('target', axis=1)

num = dfeatures.select_dtypes(include=["int64"]).columns.tolist()
cat = dfeatures.select_dtypes(include=["object"]).columns.tolist()

# Only the categorical columns are listed in the transformer
transformer = ColumnTransformer(
    transformers=[
        ("cat", OneHotEncoder(), cat),
    ]
)

clf = DecisionTreeClassifier(criterion="entropy", max_depth=5)

pipe = Pipeline(steps=[
    ('onehotenc', transformer),
    ('decisiontree', clf)
])
# Fit the training data to the pipeline
pipe.fit(dfeatures, dtarget)

pipe.named_steps['onehotenc'].get_feature_names_out().tolist()
# This call raises the error quoted above
dot_data = tree.export_graphviz(clf,
                                out_file=None,
                                feature_names=num + pipe.named_steps['onehotenc'].get_feature_names_out().tolist(),
                                class_names=['1', '0'],
                                filled=True)

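The numeric feature is not in your transformer. By default, ColumnTransformer drops every column that is not listed in transformers (remainder='drop'), so num1 never reaches the tree. Since you don't want to change it, let it pass through. You can either define the passthrough column explicitly or pass through the remaining columns. Using remainder is fine if you know num1 is the only other column that should be sent to the model.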

transformer = ColumnTransformer(
    transformers=[
        ("cat", OneHotEncoder(), cat),
    ],
    remainder='passthrough'
)

You will see that your feature names now include num1:

pipe.named_steps['onehotenc'].get_feature_names_out().tolist()

Output:
['cat__brand_NaN',
'cat__brand_aaaa',
'cat__brand_asdfasdf',
'cat__brand_sadfds',
'cat__category_as',
'cat__category_asdf',
'cat__category_asdfas',
'cat__category_asfa',
'remainder__num1']
