我正在尝试使用SciKit-Learn执行我的第一个KNN分类器。我一直在遵循用户指南和其他在线示例,但有一些事情我不确定。对于这篇文章,让我们使用下面的
X = dataY = target
-
在我读过的大多数机器学习介绍页面中,似乎都说你想要一个训练集、一个验证集和一个测试集。从我的理解来看,交叉验证允许您结合训练集和验证集来训练模型,然后您应该在测试集上对其进行测试以获得分数。然而,我在论文中看到,在很多情况下,你可以对整个数据集进行交叉验证,然后报告CV分数作为准确性。我明白,在理想情况下,你想在单独的数据上进行测试,但如果这是合法的,我想对我的整个数据集进行交叉验证,并报告这些分数
-
开始这个过程
我定义我的KNN分类器如下
knn = KNeighborsClassifier(algorithm = 'brute')
我使用
搜索最佳的n_neighborsclf = GridSearchCV(knn, parameters, cv=5)
现在如果我输入
clf.fit(X,Y)
我可以使用
检查最佳参数clf.best_params_
然后我可以得到一个分数
clf.score(X,Y)
但是-据我所知,这并没有交叉验证模型,因为它只给出1分?
如果我看过clf。best_params_ = 14
knn2 = KNeighborsClassifier(n_neighbors = 14, algorithm='brute')
cross_val_score(knn2, X, Y, cv=5)
现在我知道数据已经过交叉验证,但我不知道使用clf是否合法。拟合找到最佳参数,然后使用cross_val_score与新的KNN模型?
- 我理解这样做的"正确"方法如下
拆分为X_train, X_test, Y_train, Y_test比例火车组->对测试集应用转换
knn = KNeighborsClassifier(algorithm = 'brute')
clf = GridSearchCV(knn, parameters, cv=5)
clf.fit(X_train,Y_train)
clf.best_params_
然后我可以得到一个分数
clf.score(X_test,Y_test)
在这种情况下,是否使用最佳参数计算分数?
我希望这是有意义的。我一直在努力寻找尽可能多的东西,但我已经到了我认为更容易得到一些直接答案的地步。
在我的脑海中,我试图使用整个数据集获得一些交叉验证的分数,但也使用网格搜索(或类似的东西)来微调参数。
-
是的,你可以在你的整个数据集上进行CV,这是可行的,但我仍然建议你至少将你的数据分成两组,一组用于CV,另一组用于测试。
-
.score
函数应该根据文档返回单个float
值,这是best estimator
的分数(这是您从拟合GridSearchCV
中获得的最佳分数估计器)在给定的X,Y - 如果你看到最好的参数是14,那么你可以继续在你的模型中使用它,但是如果你给它更多的参数,你应该设置所有的参数。(-我这么说是因为你没有给出你的参数表)是的,再次检查你的简历是合法的,以防这个模型是否像它应该的那样好。
希望这能让事情更清楚:)
如果数据集很小,您可能没有足够的时间进行训练/测试分割。人们通常只根据交叉验证来估计模型的预测能力。在上面的代码中,GridSearchCV通过将训练集分成内部训练集(80%)和验证集(20%)来拟合模型(clf.fit(X, y)
)时执行5倍交叉验证。
您可以通过clf.cv_results_
访问模型性能指标,包括验证分数。您想要查看的指标包括mean_test_score
(在您的案例中,每个n_neighbor
都应该有一个分数)。您可能还想打开'mean_train_score',以了解模型是否过度拟合。请参阅下面的模型设置示例代码(注意,knn是一种非参数ML模型,对特征的规模很敏感,因此人们通常使用StandardScaler对特征进行标准化):
pipe = Pipeline([
('sc', StandardScaler()),
('knn', KNeighborsClassifier(algorithm='brute'))
])
params = {
'knn__n_neighbors': [3, 5, 7, 9, 11] # usually odd numbers
}
clf = GridSearchCV(estimator=pipe,
param_grid=params,
cv=5,
return_train_score=True) # Turn on cv train scores
clf.fit(X, y)
一个快速提示:样本数量的平方根通常是n_neighbor
的一个好选择,所以确保你在GridSearchCV中包含它。