GridSearchCV and tree classifiers



This post mentions the following:

param_grid = {'max_depth': np.arange(3, 10)}
tree = GridSearchCV(DecisionTreeClassifier(), param_grid)
tree.fit(xtrain, ytrain)
tree_preds = tree.predict_proba(xtest)[:, 1]
tree_performance = roc_auc_score(ytest, tree_preds)

Q1: Once we have run the steps above and obtained the best parameters, do we need to fit the tree again on all of the data (train + validation) with the learned parameters?

Q2: max_depth is explicitly listed in the parameter grid and can be retrieved via tree.best_params_, but what about the other parameters the grid search found? How can we get at those to build a good tree?

To answer your first question: when you create the GridSearchCV object, you can set the refit parameter to True (it is True by default). It refits the estimator with the best found parameters on the whole dataset, and you can access that estimator through the best_estimator_ attribute. It behaves like a normal estimator and supports .predict just like any other sklearn estimator.

As for your second question, you can access all the parameters of the decision tree model that was used to fit the final estimator through that same best_estimator_ attribute. But, as I said before, there is no need to fit a new classifier with the best parameters yourself, since refit=True does that for you.

Work through the example code below to see this in action:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(random_state=0)
param_grid = {'max_depth': np.arange(3, 10),
              'min_samples_leaf': np.arange(2, 10)}
tree = GridSearchCV(DecisionTreeClassifier(), param_grid)
tree.fit(X, y)
GridSearchCV(cv=None, error_score=nan,
             estimator=DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None,
                                              criterion='gini', max_depth=None,
                                              max_features=None,
                                              max_leaf_nodes=None,
                                              min_impurity_decrease=0.0,
                                              min_impurity_split=None,
                                              min_samples_leaf=1,
                                              min_samples_split=2,
                                              min_weight_fraction_leaf=0.0,
                                              presort='deprecated',
                                              random_state=None,
                                              splitter='best'),
             iid='deprecated', n_jobs=None,
             param_grid={'max_depth': array([3, 4, 5, 6, 7, 8, 9]),
                         'min_samples_leaf': array([2, 3, 4, 5, 6, 7, 8, 9])},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
             scoring=None, verbose=0)
# This is what your best estimator looks like
print(tree.best_estimator_)
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=3, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=6, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')
# you can use it directly for prediction as shown below
tree.best_estimator_.predict(X)
array([0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1,
       0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0,
       0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1,
       1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1,
       0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0])
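To make the distinction between the searched and the non-searched parameters concrete, here is a small self-contained sketch: best_params_ holds only the parameters that were in the grid, while get_params() on the refitted estimator returns every parameter of the final tree, including the defaults you never tuned. (The random_state=0 on the classifier is just an addition for reproducibility, not part of the original snippet.)

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(random_state=0)
param_grid = {'max_depth': np.arange(3, 10),
              'min_samples_leaf': np.arange(2, 10)}
tree = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid)
tree.fit(X, y)

# Only the parameters that were searched over appear here
print(tree.best_params_)

# get_params() on the refitted estimator returns ALL parameters of the
# final tree: the tuned ones plus the untouched defaults (criterion,
# splitter, min_samples_split, ...)
all_params = tree.best_estimator_.get_params()
print(all_params['max_depth'])   # same value as in best_params_
print(all_params['criterion'])   # default 'gini', never searched
```

So if you want to inspect or reuse the full configuration of the winning tree, read it off best_estimator_.get_params() rather than best_params_.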
