I recently switched from TensorFlow 2.2.0 to 2.4.1, and now I have a problem with the ModelCheckpoint callback filepath. With a tf 2.2 environment this code works fine, but with tf 2.4.1 I get an error.
from tensorflow.keras.callbacks import ModelCheckpoint

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')

history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    verbose=verbose, callbacks=[checkpoint])
Error:

KeyError: 'Failed to format this callback filepath: "path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}". Reason: \'lr\''
In ModelCheckpoint, the format names used in the filepath argument can only be epoch plus the keys that are present in logs at the end of an epoch.
You can see which keys are available in the logs with a callback like this:
import tensorflow as tf

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Log keys: {}".format(keys))

model.fit(..., callbacks=[CustomCallback()])
If you run the code above, you will see something like this:
Log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
This shows the keys you can use (plus epoch), and that lr is not among them (you used three keys in the filepath: epoch, lr, and val_loss).
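Roughly speaking, ModelCheckpoint fills the filepath pattern with Python's str.format, passing epoch plus the contents of logs, so any placeholder without a matching key fails. A minimal sketch of what goes wrong (the log values below are made up for illustration):

logs = {'loss': 0.123, 'mean_absolute_error': 0.3,
        'val_loss': 0.145, 'val_mean_absolute_error': 0.31}

filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'

try:
    print(filepath.format(epoch=1, **logs))
except KeyError as e:
    print('Missing key:', e)   # -> Missing key: 'lr'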
Solution:
You can add the learning rate to the logs yourself:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Add the current learning rate to the logs so it can be used in the filepath
        logs.update({'lr': K.eval(self.model.optimizer.lr)})
        keys = list(logs.keys())
        print("Log keys: {}".format(keys))  # you will now see `lr` among the keys

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')

# CustomCallback must come before the checkpoint so that `lr` is already in the
# logs when ModelCheckpoint formats the filepath
history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    verbose=verbose, callbacks=[CustomCallback(), checkpoint])
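An alternative that avoids depending on callback order (shown here only as a sketch; the class name LRModelCheckpoint is made up, and it assumes the optimizer's learning rate is a plain variable rather than a LearningRateSchedule) is to subclass ModelCheckpoint and inject lr into the logs just before the parent class saves:

import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

class LRModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that adds the current learning rate to the logs as `lr`
    right before saving, so `{lr:...}` can be used in the filepath."""

    def on_epoch_end(self, epoch, logs=None):
        logs = dict(logs or {})
        logs['lr'] = K.eval(self.model.optimizer.lr)
        super().on_epoch_end(epoch, logs)

checkpoint = LRModelCheckpoint(
    'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}',
    monitor='val_loss')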