我想用scikit-learn的cross_val_score()
函数对我的Keras神经网络进行交叉验证。
问题是,每次折叠后不仅会记住结果,还会记住整个 Keras 模型。因此,我想在每次折叠后使用K.clear_session()
清除此模型。但这只是上下文的详细信息。
我的主要问题是:如何在每次折叠后使用 scikit-learn 中的 cross_val_score() 运行自定义函数?换句话说:可以运行每次折叠后应该运行的回调?还是存在其他解决方法?
你可以创建一个自定义回调并重新编写这个回调的on_train_end(self,logs={})方法。这种新方法将在每个训练步骤结束时完成一些工作。像这样:
class CustomCall(Callback):
def __init__(self):
super(CustomCall, self).__init__()
def on_epoch_begin(self, epoch, logs={}):
return
def on_epoch_end(self, epoch, logs={}):
return
def on_batch_begin(self, batch, logs={}):
return
def on_train_end(self, logs={}):
# Stuff here
print('n Delete previous trained model : ')
K.clear_session()
return