我想只在损失越来越低时才保存权重,并在评估时重用它们。
lowest_loss = Inf
if loss[round] < lowest_loss:
lowest_loss = loss[round]
model_weights = transfer_learning_iterative_process.get_model_weights(state)
eval_metric = federated_eval(model_weights, [fed_valid_data])
地点:
federated_eval = tff.learning.build_federated_evaluation(model_fn)
是否有一种可能的方法来保存hdf5格式的服务器权重或作为检查点和重用它?
是的,这可以通过TFF中的helper来完成。一般来说,这种功能是由tff.program.ProgramStateManagers
实现的。可以在这里找到一个保存到文件系统的实现,并且可以在tff.simulation.run_training_process
的实现中找到示例用法。