如何在 scikit-learn 的 cross_val_score() 中每次折叠后运行函数



我想用scikit-learn的cross_val_score()函数对我的Keras神经网络进行交叉验证。

问题是,每次折叠后不仅会记住结果,还会记住整个 Keras 模型。因此,我想在每次折叠后使用K.clear_session()清除此模型。但这只是上下文的详细信息。

我的主要问题是:如何在每次折叠后使用 scikit-learn 中的 cross_val_score() 运行自定义函数?换句话说:可以运行每次折叠后应该运行的回调?还是存在其他解决方法?

你可以创建一个自定义回调并重新编写这个回调的on_train_end(self,logs={})方法。这种新方法将在每个训练步骤结束时完成一些工作。像这样:

class CustomCall(Callback):
    def __init__(self):
        super(CustomCall, self).__init__()
    def on_epoch_begin(self, epoch, logs={}):
        return
    def on_epoch_end(self, epoch, logs={}):
        return
    def on_batch_begin(self, batch, logs={}):
        return
    def on_train_end(self, logs={}):
        # Stuff here
        print('n Delete previous trained model : ')
        K.clear_session()
        return

相关内容

  • 没有找到相关文章

最新更新