如何使用cross_val_score从GridSearch中获取best_estimator参数



为了方便起见,当我使用带有cross_val_score的嵌套交叉验证时,我想知道GridSearch的结果。

使用cross_val_score时,会得到一个分数数组。接收拟合的估计器或该估计器的所选参数的摘要将是有用的。

我知道你可以自己做这件事,但只需手动实现交叉验证,但如果可以与cross_val_score一起完成,会方便得多。

有没有办法做到这一点,或者这是一个建议的功能?

scikit learn中的GridSearchCV类已经在内部进行了交叉验证。您可以传递任何CV迭代器作为GridSearchCV构造函数的cv参数。

问题的答案是,这是一个值得推荐的功能。不幸的是,您无法使用cross_val_score(截至目前,scikit 0.14)获得适合嵌套交叉验证的模型的最佳参数

参见此示例:

from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.grid_search import GridSearchCV
from sklearn.cross_validation import cross_val_score
digits = datasets.load_digits()
X = digits.data
y = digits.target
hyperparams = [{'fit_intercept':[True, False]}]
algo = LinearRegression()
grid = GridSearchCV(algo, hyperparams, cv=5, scoring='mean_squared_error')
# Nested cross validation
cross_val_score(grid, X, y)
grid.best_score_
[Out]:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-4c4ac83c58fb> in <module>()
     15 # Nested cross validation
     16 cross_val_score(grid, X, y)
---> 17 grid.best_score_
AttributeError: 'GridSearchCV' object has no attribute 'best_score_'

(还要注意,你从cross_val_score中得到的分数不是scoring中定义的分数,这里是均方误差。你看到的是最佳估计器的分数函数。v0.14的错误在这里描述。)

sklearn v0.20.0(将于2018年底发布)中,如果需要,函数cross_validate会暴露经过训练的估计量。

请参阅此处新功能的相应拉取请求。这样的东西会起作用:

from sklearn.metrics.scorer import check_scoring
from sklearn.model_selection import cross_validate
scorer = check_scoring(estimator=gridSearch, scoring=scoring)
cvRet = cross_validate(estimator=gridSearch, X=X, y=y,
                       scoring={'score': scorer}, cv=cvOuter,
                       return_train_score=False,
                       return_estimator=True,
                       n_jobs=nJobs)
scores = cvRet['test_score']  # Equivalent to output of cross_val_score()
estimators = cvRet['estimator']

如果是return_estimator=True,则可以从返回的字典中检索估计量作为cvRet['estimator']。存储在CCD_ 11中的列表相当于CCD_。请参见此处cross_val_score()是如何通过cross_validate()实现的。

相关内容

  • 没有找到相关文章

最新更新