Restoring the best checkpoint into an Estimator in TensorFlow 2.x



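In short, I set up a data input pipeline with the TensorFlow Dataset API. I then implemented a CNN classification model in Keras and converted it into an Estimator. I passed my input_fn to the Estimator's TrainSpec and EvalSpec to supply the training and evaluation data. As the last step, I started training with tf.estimator.train_and_evaluate: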

```python
import tensorflow as tf

def my_input_fn(tfrecords_path):
    dataset = (...)
    return batch_fbanks, batch_labels

def build_model():
    model = tf.keras.models.Sequential()
    model.add(...)
    model.compile(...)
    return model

model = build_model()
run_config = tf.estimator.RunConfig(model_dir, save_summary_steps=100, save_checkpoints_steps=1000)
estimator = tf.keras.estimator.model_to_estimator(model, config=run_config)

def serving_input_receiver_fn():
    inputs = {'Conv1_input': tf.compat.v1.placeholder(shape=[None, 11, 120, 1], dtype=tf.float32)}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)

# BestExporter's first positional parameter is `name`, so pass the receiver fn by keyword.
exporter = tf.estimator.BestExporter(name="best_exporter",
                                     serving_input_receiver_fn=serving_input_receiver_fn,
                                     exports_to_keep=5)
train_spec_dnn = tf.estimator.TrainSpec(input_fn=lambda: my_input_fn(train_data_path), hooks=[hook])
eval_spec_dnn = tf.estimator.EvalSpec(input_fn=lambda: my_eval_input_fn(eval_data_path),
                                      exporters=exporter, start_delay_secs=0, throttle_secs=15)
tf.estimator.train_and_evaluate(estimator, train_spec_dnn, eval_spec_dnn)
```

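As shown above, I used tf.estimator.BestExporter to keep the 5 best checkpoints. Once training is finished, I want to reload the best model and turn it back into an Estimator, so that I can re-evaluate it and make predictions on a new dataset. My problem is restoring the checkpoint into an Estimator. I have tried several solutions, but none of them gave me an Estimator object on which I could call its evaluate and predict methods.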
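A note on `exports_to_keep=5`: as far as I can tell from the BestExporter source, with the default `compare_fn` a new SavedModel is exported only when the eval loss is strictly lower than the best loss seen so far, and garbage collection then keeps the newest `exports_to_keep` export directories. The five kept exports are therefore the five most recent improvements, which is not always the same set as the five lowest losses. The pure-Python simulation below models that behaviour under those assumptions; `simulate_best_exporter` is a hypothetical helper, not part of TensorFlow.

```python
from typing import List, Optional, Sequence


def simulate_best_exporter(eval_losses: Sequence[float],
                           exports_to_keep: Optional[int] = 5) -> List[int]:
    """Return the indices of evaluations whose exports would survive.

    Hypothetical model of BestExporter with its default loss-based compare_fn:
    export when the loss is strictly smaller than the best so far, then keep
    only the newest `exports_to_keep` exports (None keeps all of them).
    """
    if exports_to_keep is not None and exports_to_keep <= 0:
        raise ValueError('exports_to_keep must be positive or None')
    best = None
    exported = []
    for index, loss in enumerate(eval_losses):
        if best is None or loss < best:
            best = loss
            exported.append(index)
    if exports_to_keep is None:
        return exported
    return exported[-exports_to_keep:]


losses = [0.9, 0.8, 0.85, 0.7, 0.75, 0.6, 0.5, 0.4]
print(simulate_best_exporter(losses))  # [1, 3, 5, 6, 7]
```

Here evaluation 4 (loss 0.75) is among the five lowest losses, but it was never exported because it did not beat 0.7, while evaluation 1 (loss 0.8) survives.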

To elaborate, each best-checkpoint directory is laid out as follows:

```text
./
    variables/
        variables.data-00000-of-00002
        variables.data-00001-of-00002
        variables.index
    saved_model.pb
```
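The layout above is a SavedModel export (saved_model.pb plus a variables/ folder), not an Estimator checkpoint directory: it has no `checkpoint` state file, so `tf.train.latest_checkpoint` returns None when pointed at it. The small stdlib helper below, a hypothetical `describe_export_dir`, only checks for those marker files to tell the two layouts apart; it does not validate the contents.

```python
import os


def describe_export_dir(path: str) -> str:
    """Hypothetical helper: classify a directory by the marker files it contains."""
    if os.path.isfile(os.path.join(path, 'saved_model.pb')):
        # SavedModel export, e.g. one written by BestExporter.
        return 'saved_model'
    if os.path.isfile(os.path.join(path, 'checkpoint')):
        # Checkpoint directory with the state file read by tf.train.latest_checkpoint.
        return 'checkpoint_dir'
    return 'unknown'
```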

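So the question is: how do I get an Estimator object from the best checkpoint, so that I can use it to evaluate my model and make predictions on new data?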

Note: the solutions I found rely on TensorFlow v1 functionality and don't solve my problem, since I am using TF v2.

Thanks a lot; any help would be appreciated.

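You can use the class below, which subclasses tf.estimator.BestExporter.

In addition to exporting the best model (saved_model.pb and its variables), it also copies the checkpoint files of the best exported model into a separate folder.

Here is the class: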

```python
import glob
import os
import shutil

import tensorflow as tf

# Folder that will receive a copy of the best exported checkpoint's files.
BEST_CHECKPOINTS_PATH_TO = 'PATH TO BEST EXPORTER CHECKPOINT FILES TO BE SAVED'


class BestCheckpointsExporter(tf.estimator.BestExporter):
    def export(self, estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
        # BestExporter compares eval_result with the best result so far and, if it
        # is better, writes the SavedModel (.pb + variables/); otherwise it returns None.
        export_result = super().export(estimator, export_path, checkpoint_path,
                                       eval_result, is_the_final_export)
        if export_result is None:
            print('Keeping the current best model (this evaluation: {}).'.format(eval_result))
            return export_result

        # A better model was exported: also copy its checkpoint files
        # (e.g. model.ckpt-N.index, model.ckpt-N.data-*) to the best-checkpoint folder.
        os.makedirs(BEST_CHECKPOINTS_PATH_TO, exist_ok=True)
        for name in glob.glob(glob.escape(checkpoint_path) + '.*'):
            shutil.copy(name, os.path.join(BEST_CHECKPOINTS_PATH_TO, os.path.basename(name)))

        # Write the "checkpoint" state file that tf.train.latest_checkpoint reads.
        with open(os.path.join(BEST_CHECKPOINTS_PATH_TO, 'checkpoint'), 'w') as f:
            f.write('model_checkpoint_path: "{}"\n'.format(os.path.basename(checkpoint_path)))
        return export_result
```
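To see what the exporter leaves on disk without running TensorFlow, the stdlib sketch below mimics its copy step (`copy_checkpoint`) and a simplified version of how `tf.train.latest_checkpoint` resolves the `checkpoint` state file (`resolve_latest`). Both names are hypothetical. The real `latest_checkpoint` parses the file as a protocol-buffer text message and also verifies that the checkpoint files exist, which this sketch skips.

```python
import glob
import os
import shutil
from typing import Optional


def copy_checkpoint(checkpoint_path: str, dest_dir: str) -> None:
    """Copy every file of one checkpoint prefix and write a 'checkpoint' state file."""
    os.makedirs(dest_dir, exist_ok=True)
    # "model.ckpt-1000.*" matches model.ckpt-1000.index / .data-* but not model.ckpt-10000.*
    for name in glob.glob(glob.escape(checkpoint_path) + '.*'):
        shutil.copy(name, os.path.join(dest_dir, os.path.basename(name)))
    with open(os.path.join(dest_dir, 'checkpoint'), 'w') as f:
        f.write('model_checkpoint_path: "{}"\n'.format(os.path.basename(checkpoint_path)))


def resolve_latest(dest_dir: str) -> Optional[str]:
    """Simplified stand-in for tf.train.latest_checkpoint (no existence check)."""
    state_file = os.path.join(dest_dir, 'checkpoint')
    if not os.path.isfile(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            key, _, value = line.partition(':')
            if key.strip() == 'model_checkpoint_path':
                path = value.strip().strip('"')
                # Relative prefixes are resolved against the checkpoint directory.
                return path if os.path.isabs(path) else os.path.join(dest_dir, path)
    return None
```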

Example usage of the class
Simply replace your exporter with an instance of this class, passing it your serving_input_receiver_fn:

```python
def serving_input_receiver_fn():
    inputs = {'my_dense_input': tf.compat.v1.placeholder(shape=[None, 4], dtype=tf.float32)}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)

exporter = BestCheckpointsExporter(serving_input_receiver_fn=serving_input_receiver_fn)
train_spec_dnn = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=5)
eval_spec_dnn = tf.estimator.EvalSpec(input_fn=input_fn, exporters=exporter,
                                      start_delay_secs=0, throttle_secs=15)
# When training locally, train_and_evaluate returns (eval metrics, export results).
eval_result, export_results = tf.estimator.train_and_evaluate(keras_estimator, train_spec_dnn, eval_spec_dnn)
```

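At this point, the checkpoint files of the best exported model are saved in the folder you specified.

To load the checkpoint files, follow these steps:
Step 1: Rebuild the model instance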

```python
def build_model():
    model = tf.keras.models.Sequential()
    model.add(...)
    model.compile(...)
    return model

model = build_model()
```

Step 2: Restore the weights with the model's load_weights API
Reference: https://www.tensorflow.org/tutorials/keras/save_and_load

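```python
# Resolves the "checkpoint" state file written by BestCheckpointsExporter.
ck_path = tf.train.latest_checkpoint('PATH TO BEST EXPORTER CHECKPOINT FILES')
model.load_weights(ck_path)

# From here you can call predict and evaluate on the trained model.
# PREDICT (x: new input data matching the model's input shape)
prediction = model.predict(x)

# EVALUATE (assumes input_fn() returns a tf.data.Dataset of (features, labels) batches;
# take(1) evaluates only the first batch)
for features_batch, labels_batch in input_fn().take(1):
    model.evaluate(features_batch, labels_batch)
```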

Note: all of this was tested on Google Colab.
