如何在没有显式模型.fit的情况下设置tf.keras.callbacks.ModelCheckpoint



我想在代码的转换过程中添加Checkpoint。通过在model.fit中设置callbacks = callbacks,我知道了方法。但是,在代码中,没有显式调用model.fitinsead到K.function,如下所示。有人能告诉我设置检查站的正确位置在哪里吗?完整的代码可以通过这个github链接查看。

vae_model = vae_util.create_vae(input_shape)
vae_model.compile(optimizer=opt, loss='mse')
rec_loss = vae_loss(vae_model.output, train_target)
total_loss = rec_loss
updates = opt.get_updates(total_loss, vae_model.trainable_weights)
iterate = K.function(vae_model.inputs + [train_target], [rec_loss], updates=updates)
eval_rec_loss = vae_loss(vae_model.output, test_target)
evaluate = K.function(vae_model.inputs + [test_target], [eval_rec_loss])   

原始代码在第139行中已经保存了安全点

相关内容

最新更新