如何从张量流检查点文件正确恢复网络训练?



我正在努力恢复一天的模型,但没有任何成功。我的代码由一个class TF_MLPRegressor()组成,我在构造函数中定义网络架构。然后我调用fit()函数进行训练。因此,这就是我在fit()函数中保存具有 1 个隐藏层的简单感知器模型的方式:

starting_epoch = 0
# Launch the graph
tf.set_random_seed(self.random_state)   # fix the random seed before creating the Session in order to take effect!
if hasattr(self, 'sess'):
self.sess.close()
del self.sess   # delete Session to release memory
gc.collect()
self.sess = tf.Session(config=self.config) # save the session to predict from new data
# Create a saver object which will save all the variables
saver = tf.train.Saver(max_to_keep=2)  # max_to_keep=2 means to not keep more than 2 checkpoint files
self.sess.run(tf.global_variables_initializer())
# ... (each 100 epochs)
saver.save(self.sess, self.checkpoint_dir+"/resume", global_step=epoch)

然后,我创建一个具有完全相同的输入参数值的新TF_MLPRegressor()实例,并调用fit()函数来恢复模型,如下所示:

self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
starting_epoch = int(ckpt.split('-')[-1])
metagraph = ".".join([ckpt, 'meta'])
saver = tf.train.import_meta_graph(metagraph)
self.sess.run(tf.global_variables_initializer())    # Initialize variables
lhl = tf.trainable_variables()[2]
lhlA = lhl.eval(session=self.sess)
saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model
lhlB = lhl.eval(session=self.sess)
print lhlA == lhlB

lhlAlhlB是恢复前后的最后一个隐藏层权重,根据我的代码,它们完全匹配,即保存的模型不会加载到会话中。我做错了什么?

我找到了解决方法!奇怪的是,元图不包含我定义的所有变量或为它们分配新名称。对于构造函数中的示例,我定义了将携带输入特征向量和实验值的张量:

self.x = tf.placeholder("float", [None, feat_num], name='x')
self.y = tf.placeholder("float", [None], name='y')

但是,当我tf.reset_default_graph()并加载元图时,我得到以下变量列表:

[
<tf.Variable 'Variable:0' shape=(300, 300) dtype=float32_ref>, 
<tf.Variable 'Variable_1:0' shape=(300,) dtype=float32_ref>, 
<tf.Variable 'Variable_2:0' shape=(300, 1) dtype=float32_ref>, 
<tf.Variable 'Variable_3:0' shape=(1,) dtype=float32_ref>
]

作为记录,每个输入要素向量具有 300 个要素。无论如何,当我后来尝试使用以下方法启动训练时:

_, c, p = self.sess.run([self.optimizer, self.cost, self.pred], 
feed_dict={self.x: batch_x, self.y: batch_y, self.isTrain: True})

我收到如下错误:

"TypeError: Cannot interpret feed_dict key as Tensor: Tensor 'x' is not an element of this graph."

因此,由于每次创建class TF_MLPRegressor()实例时,我都在构造函数中定义了网络架构,因此我决定不加载元图并且它起作用了!我不知道为什么 TF 不将所有变量保存到元图中,可能是因为我明确定义了网络架构(我不使用包装器或默认层(,如下例所示:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py

总而言之,我按照第一条消息中的描述保存我的模型,但要恢复它们,我使用它:

saver = tf.train.Saver(max_to_keep=2)
self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
self.sess.run(tf.global_variables_initializer())
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model

最新更新