我正在尝试在scikit-learn中结合递归特征消除和网格搜索。正如您可以从下面的代码中看到的那样(它可以工作),我能够从网格搜索中获得最佳估计器,然后将该估计器传递给RFECV。然而,我宁愿先做RFECV,然后再做网格搜索。问题是,当我将选择器从RFECV传递到网格搜索时,它不接受它:
ValueError:估计器RFECV的参数引导无效
是否有可能从RFECV获得选择器并将其直接传递给RandomizedSearchCV,或者这在程序上不是正确的事情?
from sklearn.datasets import make_classification
from sklearn.feature_selection import RFECV
from sklearn.grid_search import GridSearchCV, RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
from scipy.stats import randint as sp_randint
# Build a classification task using 3 informative features
X, y = make_classification(n_samples=1000, n_features=25, n_informative=5, n_redundant=2, n_repeated=0, n_classes=8, n_clusters_per_class=1, random_state=0)
grid = {"max_depth": [3, None],
"min_samples_split": sp_randint(1, 11),
"min_samples_leaf": sp_randint(1, 11),
"bootstrap": [True, False],
"criterion": ["gini", "entropy"]}
estimator = RandomForestClassifierCoef()
clf = RandomizedSearchCV(estimator, param_distributions=grid, cv=7)
clf.fit(X, y)
estimator = clf.best_estimator_
selector = RFECV(estimator, step=1, cv=4)
selector.fit(X, y)
selector.grid_scores_
最好的方法是将RFECV嵌套在随机搜索中,使用来自此SO答案的方法。基于上述问题代码和SO答案的一些示例代码:
from sklearn.datasets import make_classification
from sklearn.feature_selection import RFECV
from sklearn.grid_search import GridSearchCV, RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
from scipy.stats import randint as sp_randint
# Build a classification task using 5 informative features
X, y = make_classification(n_samples=1000, n_features=25, n_informative=5, n_redundant=2, n_repeated=0, n_classes=8, n_clusters_per_class=1, random_state=0)
grid = {"estimator__max_depth": [3, None],
"estimator__min_samples_split": sp_randint(1, 11),
"estimator__min_samples_leaf": sp_randint(1, 11),
"estimator__bootstrap": [True, False],
"estimator__criterion": ["gini", "entropy"]}
estimator = RandomForestClassifier()
selector = RFECV(estimator, step=1, cv=4)
clf = RandomizedSearchCV(selector, param_distributions=grid, cv=7)
clf.fit(X, y)
print(clf.grid_scores_)
print(clf.best_estimator_.n_features_)