如何在 sklearn 中交叉验证 keras 包装器估计器时访问历史对象



我想在交叉验证中查看每个拆分的损失/错误进度。 keras.wrappers.scikit_learn.KerasClassifier 的 fit 方法返回一个包含我想要的数据的 history 对象,但在 sklearn.model_selection.cross_validate 变量方法中运行时无法访问它。

如何访问每个拆分中每个纪元的历史对象?

您可能可以使用 CSVLogger 回调来访问完整的历史记录。设置 CSVLogger 回调很容易,它将以您指定的任何文件名记录:{epoch、acc、loss、val_acc、val_loss}。

在我的代码中,我执行以下操作:

keras_classifier.fit(X, y, groups=None, 
    callbacks=[keras.callbacks.CSVLogger(filename, append=True)])

设置append=True应确保所有拆分的所有数据都包含在文件中。

需要考虑的事项:

  • 我不确定这是否适用于n_jobs=-1(用于在多个处理器上分配处理(,但是如果您运行单线程,它应该可以工作。
  • 请务必在运行分类器之前(或在初始化期间(删除该文件,以避免无限期地追加到该文件。

最新更新