Keras:train_on_batch有没有历史+进度的示例代码



在Python中,为了进行更精细的控制,我将从Keras的Model.fit循环转移到Model.train_on_batch循环。但是fit返回的进度条和History对象是有用的。在浪费时间从头开始实现它们之前,我想知道是否有人发现了使用train_on_batch复制进度条和历史的示例代码?

(注意。我看了fit的源代码,但有足够多的间接层,不容易准确地挖掘出它在做什么。我也发现了这一点,这很有帮助,但没有相关的功能。(

定义了EPOCHStrain_generator和验证数据val_x, val_y后,可以替换

history = model.fit(train_generator, validation_data = (val_x, val_y), epochs = EPOCHS)

带有以下代码:

callbacks = tf.keras.callbacks.CallbackList(
None, 
add_history = True,
add_progbar = True,
model = model,
epochs = EPOCHS,
verbose = 1,
steps = len(train_generator)
)
callbacks.on_train_begin()
for epoch in range(EPOCHS):
model.reset_metrics()
callbacks.on_epoch_begin(epoch)
for i in range(len(train_generator)):
callbacks.on_train_batch_begin(i)
logs = model.train_on_batch(*train_generator[i], reset_metrics = False, return_dict = True)              
callbacks.on_train_batch_end(i, logs)
validation_logs = model.evaluate(val_x, val_y, callbacks = callbacks, return_dict = True)
logs.update({'val_' + name: v for name, v in validation_logs.items()})
callbacks.on_epoch_end(epoch, logs)
train_generator.on_epoch_end()
callbacks.on_train_end(epoch_logs)
history = model.history

所以在查看了keras的源代码后,我发现tf.keras.callbacks.ProgbarLogger和tf.keras.allbacks.History就是您想要的

源代码

keras/callbacks.py#L259

keras/callbacks.py#L263

最新更新