有人可以解释一下predict()
方法在 scikit learn 的 kmeans 实现中有什么用吗?官方文档指出其用途为:
预测 X 中每个样本所属的最接近聚类。
但是我也可以通过在该方法上训练模型来获取输入集 X 的每个样本fit_transform()
聚类编号/标签。那么predict()
方法有什么用呢?它是否应该为看不见的数据指出最近的集群?如果是,那么如果执行降维度量(如 SVD),如何处理新数据点?
这是一个类似的问题,但我仍然认为它没有真正的帮助。
predict() 方法有什么用?它是否应该为看不见的数据指出最近的集群?
是的,没错。
那么,如果您执行降维度量(例如 SVD),如何处理新数据点?
在将未看到的数据传递给.predict()
之前,您可以对未看到的数据应用相同的降维方法。下面是一个典型的工作流程:
# prerequisites:
# x_train: training data
# x_test: "unseen" testing data
# km: initialized `KMeans()` instance
# dr: initialized dimensionality reduction instance (such as `TruncatedSVD()`)
# fitting
x_dr = dr.fit_transform(x_train)
y = km.fit_predict(x_dr)
# ...
# working with unseen data (models have been fitted before)
x_dr = dr.transform(x_test)
y = km.predict(x_dr)
# ...
实际上,诸如fit_transform
和fit_predict
之类的方法是为了方便。y = km.fit_predict(x)
相当于y = km.fit(x).predict(x)
。
我认为如果我们按如下方式编写配件部分,更容易看到发生了什么:
# fitting
dr.fit(x_train)
x_dr = dr.transform(x_train)
km.fit(x_dr)
y = km.predict(x_dr)
除了调用.fit()
在拟合期间和未见过的数据中平等使用的模型。
总结:
.fit()
的目的是使用数据训练模型。.predict()
或.transform()
的目的是将经过训练的模型应用于数据。- 如果要拟合模型并在训练期间将其应用于相同的数据,为方便起见,可以使用
.fit_predict()
或.fit_transform()
。 - 链接多个模型(如降维和聚类)时,请在拟合和测试期间以相同的顺序应用它们。