KeyError: '无法格式化此回调文件路径: 原因: \'lr\



我最近从Tensorflow 2.2.0切换到2.4.1,现在我有一个ModelCheckpoint回调路径的问题。如果我使用tf 2.2的环境,这段代码可以正常工作,但当我使用tf 2.4.1时,会出现错误。

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')
history = model.fit(training_data, training_data,
epochs=10,
batch_size=32,
shuffle=True,
validation_data=(validation_data, validation_data),
verbose=verbose, callbacks=[checkpoint])

错误:

KeyError: 'Failed to format this callback filepath: "path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}"原因:lr">

ModelCheckpoint中,filepath参数的格式化名称,只能包含:epoch+纪元结束后logs中的键.

您可以在日志中看到这样的可用键:

class CustomCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
keys = list(logs.keys())
print("Log keys: {}".format(keys))
model.fit(..., callbacks=[CustomCallback()])

如果你运行上面的代码,你会看到这样的内容:

Log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']

显示您可以使用的可用键(加上epoch)和lr不可用(您在filepath名称中使用了3个密钥:epoch,lrval_loss)。


解决方案:

你可以自己添加学习速率到日志中:

import tensorflow.keras.backend as K
class CustomCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
logs.update({'lr': K.eval(self.model.optimizer.lr)})
keys = list(logs.keys())
print("Log keys: {}".format(keys)) #you will see now `lr` available
checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')
history = model.fit(training_data, training_data,
epochs=10,
batch_size=32,
shuffle=True,
validation_data=(validation_data, validation_data),
verbose=verbose, callbacks=[checkpoint, CustomCallback()])

相关内容

  • 没有找到相关文章

最新更新