在Python中,为了进行更精细的控制,我将从Keras的Model.fit
循环转移到Model.train_on_batch
循环。但是fit
返回的进度条和History对象是有用的。在浪费时间从头开始实现它们之前,我想知道是否有人发现了使用train_on_batch
复制进度条和历史的示例代码?
(注意。我看了fit
的源代码,但有足够多的间接层,不容易准确地挖掘出它在做什么。我也发现了这一点,这很有帮助,但没有相关的功能。(
定义了EPOCHS
、train_generator
和验证数据val_x, val_y
后,可以替换
history = model.fit(train_generator, validation_data = (val_x, val_y), epochs = EPOCHS)
带有以下代码:
callbacks = tf.keras.callbacks.CallbackList(
None,
add_history = True,
add_progbar = True,
model = model,
epochs = EPOCHS,
verbose = 1,
steps = len(train_generator)
)
callbacks.on_train_begin()
for epoch in range(EPOCHS):
model.reset_metrics()
callbacks.on_epoch_begin(epoch)
for i in range(len(train_generator)):
callbacks.on_train_batch_begin(i)
logs = model.train_on_batch(*train_generator[i], reset_metrics = False, return_dict = True)
callbacks.on_train_batch_end(i, logs)
validation_logs = model.evaluate(val_x, val_y, callbacks = callbacks, return_dict = True)
logs.update({'val_' + name: v for name, v in validation_logs.items()})
callbacks.on_epoch_end(epoch, logs)
train_generator.on_epoch_end()
callbacks.on_train_end(epoch_logs)
history = model.history
所以在查看了keras的源代码后,我发现tf.keras.callbacks.ProgbarLogger和tf.keras.allbacks.History就是您想要的
源代码
keras/callbacks.py#L259
keras/callbacks.py#L263