如何确定GridSearchCV中每个评分指标的最佳参数和最佳分数



我正在尝试评估多个评分指标,以确定模型性能的最佳参数。也就是说:

要最大化F1,我应该使用这些参数。为了最大限度地提高精度应该使用这些参数。

我正在从这个sklearn页面中完成以下示例

import numpy as np
from sklearn.datasets import make_hastie_10_2
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
X, y = make_hastie_10_2(n_samples=5000, random_state=42)

scoring = {'PRECISION': 'precision', 'F1': 'f1'}
gs = GridSearchCV(DecisionTreeClassifier(random_state=42),
param_grid={'min_samples_split': range(2, 403, 10)},
scoring=scoring, refit='F1', return_train_score=True)
gs.fit(X, y)
best_params = gs.best_params_
best_estimator = gs.best_estimator_
print(best_params)
print(best_estimator)

哪个收益率:

{'min_samples_split': 62}
DecisionTreeClassifier(min_samples_split=62, random_state=42)

然而,我要寻找的是找到每个度量结果,因此在这种情况下,对于F1精度

如何在GridSearchCV中获得每种评分指标的最佳参数?

注意-我相信这与我对refit='F1'的使用有关,但不确定如何在那里使用多个度量?

要做到这一点,您必须深入研究整个网格搜索CV过程的详细结果;幸运的是,这些详细的结果在GridSearchCV对象(docs(的cv_results_属性中返回。

我已经按原样重新运行了您的代码,但我不会在这里重新键入;可以说,尽管明确地设置了随机数生成器的种子,但我得到了不同的最终结果(我想是由于版本不同(:

{'min_samples_split': 322}
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
max_depth=None, max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=322,
min_weight_fraction_leaf=0.0, presort='deprecated',
random_state=42, splitter='best')

但这对于当前的问题来说并不重要。

使用返回的cv_results_字典的最简单方法是将其转换为pandas数据帧:

import pandas as pd
cv_results = pd.DataFrame.from_dict(gs.cv_results_)

尽管如此,由于它包含了太多的信息(列(,我将在这里进一步简化它来演示这个问题(请自行更全面地探索(:

df = cv_results[['params', 'mean_test_PRECISION', 'rank_test_PRECISION', 'mean_test_F1', 'rank_test_F1']]
pd.set_option("display.max_rows", None, "display.max_columns", None)
pd.set_option('expand_frame_repr', False)
print(df)

结果:

params  mean_test_PRECISION  rank_test_PRECISION  mean_test_F1  rank_test_F1
0     {'min_samples_split': 2}             0.771782                    1      0.763041            41
1    {'min_samples_split': 12}             0.768040                    2      0.767331            38
2    {'min_samples_split': 22}             0.767196                    3      0.776677            29
3    {'min_samples_split': 32}             0.760282                    4      0.773634            32
4    {'min_samples_split': 42}             0.754572                    8      0.777967            26
5    {'min_samples_split': 52}             0.754034                    9      0.777550            27
6    {'min_samples_split': 62}             0.758131                    5      0.773348            33
7    {'min_samples_split': 72}             0.756021                    6      0.774301            30
8    {'min_samples_split': 82}             0.755612                    7      0.768065            37
9    {'min_samples_split': 92}             0.750527                   10      0.771023            34
10  {'min_samples_split': 102}             0.741016                   11      0.769896            35
11  {'min_samples_split': 112}             0.740965                   12      0.765353            39
12  {'min_samples_split': 122}             0.731790                   13      0.763620            40
13  {'min_samples_split': 132}             0.723085                   14      0.768605            36
14  {'min_samples_split': 142}             0.713345                   15      0.774117            31
15  {'min_samples_split': 152}             0.712958                   16      0.776721            28
16  {'min_samples_split': 162}             0.709804                   17      0.778287            24
17  {'min_samples_split': 172}             0.707080                   18      0.778528            22
18  {'min_samples_split': 182}             0.702621                   19      0.778516            23
19  {'min_samples_split': 192}             0.697630                   20      0.778103            25
20  {'min_samples_split': 202}             0.693011                   21      0.781047            10
21  {'min_samples_split': 212}             0.693011                   21      0.781047            10
22  {'min_samples_split': 222}             0.693011                   21      0.781047            10
23  {'min_samples_split': 232}             0.692810                   24      0.779705            13
24  {'min_samples_split': 242}             0.692810                   24      0.779705            13
25  {'min_samples_split': 252}             0.692810                   24      0.779705            13
26  {'min_samples_split': 262}             0.692810                   24      0.779705            13
27  {'min_samples_split': 272}             0.692810                   24      0.779705            13
28  {'min_samples_split': 282}             0.692810                   24      0.779705            13
29  {'min_samples_split': 292}             0.692810                   24      0.779705            13
30  {'min_samples_split': 302}             0.692810                   24      0.779705            13
31  {'min_samples_split': 312}             0.692810                   24      0.779705            13
32  {'min_samples_split': 322}             0.688417                   33      0.782772             1
33  {'min_samples_split': 332}             0.688417                   33      0.782772             1
34  {'min_samples_split': 342}             0.688417                   33      0.782772             1
35  {'min_samples_split': 352}             0.688417                   33      0.782772             1
36  {'min_samples_split': 362}             0.688417                   33      0.782772             1
37  {'min_samples_split': 372}             0.688417                   33      0.782772             1
38  {'min_samples_split': 382}             0.688417                   33      0.782772             1
39  {'min_samples_split': 392}             0.688417                   33      0.782772             1
40  {'min_samples_split': 402}             0.688417                   33      0.782772             1

列的名称应该是不言自明的;它们包括所尝试的参数、所使用的每一个度量的分数以及相应的秩(1表示最佳(。例如,您可以立即看到,尽管'min_samples_split': 322确实给出了最好的F1分数,但它不是唯一的参数设置,而且还有更多的设置也给出了最佳F1分数和结果中1的相应rank_test_F1

从这一点来看,获得你想要的信息是微不足道的;例如,以下是两个度量中每一个度量的最佳模型:

print(df.loc[df['rank_test_PRECISION']==1]) # best precision
# result:
params  mean_test_PRECISION  rank_test_PRECISION  mean_test_F1  rank_test_F1
0  {'min_samples_split': 2}             0.771782                    1      0.763041            41
print(df.loc[df['rank_test_F1']==1]) # best F1
# result:
params  mean_test_PRECISION  rank_test_PRECISION  mean_test_F1  rank_test_F1
32  {'min_samples_split': 322}             0.688417                   33      0.782772             1
33  {'min_samples_split': 332}             0.688417                   33      0.782772             1
34  {'min_samples_split': 342}             0.688417                   33      0.782772             1
35  {'min_samples_split': 352}             0.688417                   33      0.782772             1
36  {'min_samples_split': 362}             0.688417                   33      0.782772             1
37  {'min_samples_split': 372}             0.688417                   33      0.782772             1
38  {'min_samples_split': 382}             0.688417                   33      0.782772             1
39  {'min_samples_split': 392}             0.688417                   33      0.782772             1
40  {'min_samples_split': 402}             0.688417                   33      0.782772             1

相关内容

  • 没有找到相关文章

最新更新