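I am trying to fit a sklearn DecisionTree on a DataFrame using the code below, but I get the error: Length of feature_names, 9 does not match number of features, 8. The decision tree seems to be fitted only on the one-hot-encoded categorical features, not on the numeric feature. How can I include the numeric feature in the decision tree model?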
import pandas as pd
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import make_pipeline
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn import tree
from matplotlib import pyplot as plt
import graphviz
from sklearn.compose import make_column_transformer
from sklearn.linear_model import LinearRegression
df = pd.DataFrame({'brand': ['aaaa', 'asdfasdf', 'sadfds', 'NaN'],
                   'category': ['asdf', 'asfa', 'asdfas', 'as'],
                   'num1': [1, 1, 0, 0],
                   'target': [1, 0, 0, 1]})
df
dtarget=df['target']
dfeatures=df.drop('target', axis=1)
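# "number" matches any numeric dtype, so it does not depend on the platform's default integer width
num = dfeatures.select_dtypes(include=["number"]).columns.tolist()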
cat = dfeatures.select_dtypes(include=["object"]).columns.tolist()
transformer = ColumnTransformer(
    transformers=[
        ("cat", OneHotEncoder(), cat),
    ]
)
clf = DecisionTreeClassifier(criterion="entropy", max_depth=5)
pipe = Pipeline(steps=[
    ('onehotenc', transformer),
    ('decisiontree', clf)
])
#Fit the training data to the pipeline
pipe.fit(dfeatures, dtarget)
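pipe.named_steps['onehotenc'].get_feature_names_out().tolist()
# The feature_names argument below is what raises the error:
# num (1 name) + encoded names (8) = 9, but the tree was fitted on 8 features
dot_data = tree.export_graphviz(clf,
                                out_file=None,
                                feature_names=num + pipe.named_steps['onehotenc'].get_feature_names_out().tolist(),
                                class_names=['0', '1'],  # must follow the ascending order of clf.classes_
                                filled=True)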
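The numeric feature is not in your transformer, so it is dropped before the tree ever sees it. Since you do not want to change it, let it pass through. You can either define the passthrough columns explicitly or pass through the remainder. If you know these are the only other columns that should reach the model, remainder is fine.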
transformer = ColumnTransformer(
    transformers=[
        ("cat", OneHotEncoder(), cat),
    ],
    remainder='passthrough'
)
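You will then see that the feature names already include the num1 column, so this list can be passed to feature_names as is, without prepending num: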
pipe.named_steps['onehotenc'].get_feature_names_out().tolist()
Output:
['cat__brand_NaN',
 'cat__brand_aaaa',
 'cat__brand_asdfasdf',
 'cat__brand_sadfds',
 'cat__category_as',
 'cat__category_asdf',
 'cat__category_asdfas',
 'cat__category_asfa',
 'remainder__num1']