TF LSTM:以后从培训课程中保存状态



我试图从培训中保存最新的LSTM状态,以便在预测阶段重复使用。我遇到的问题是,在TF LSTM模型中,状态是通过占位符和Numpy数组的组合从一个训练迭代到下一步传递的 - 当会话时,默认情况下似乎都不包含在图中保存。

为了解决此问题,我正在创建一个专用的TF变量来保存该状态的最新版本以将其添加到会话图中,例如:

# latest State from last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now add to TF variable:
savedState = tf.Variable(ostate, dtype=tf.float32, name='savedState')
tf.variables_initializer([savedState]).run()
save_path = saver.save(sess, pathModel + '/my_model.ckpt')

这似乎可以很好地将savedState变量添加到保存的会话图中,并且以后在会话的其余部分中易于恢复。

问题是,我设法在恢复会话中实际使用该变量的唯一方法是,如果我在恢复后初始化会话中的所有变量(似乎将所有训练的变量重置,包括重量/偏见/等!(。如果我首先初始化变量,然后恢复会话(在保存训练有素的Varialbes方面工作正常(,那么我正在遇到一个错误,我正在尝试访问一个非初始化的变量。

我知道有一种方法可以初始化特定的个体杂质(我最初保存时正在使用它(,但是问题是,当我们恢复它们时,我们将它们称为字符串,我们不仅通过变量本身?!

# This produces an error 'trying to use an uninitialized varialbe
gInit = tf.global_variables_initializer().run()
new_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
new_saver.restore(sess, pathModel + 'my_model.ckpt')
fullState = sess.run('savedState:0')

完成此操作的正确方法是什么?作为解决方法,我目前正在将状态保存到CSV,就像一个Numpy数组一样,然后以相同的方式恢复。它可以正常工作,但显然不是最清洁的解决方案,因为保存/还原TF会话的其他方面都可以很好地工作。

任何建议都赞赏!

**编辑:如下所述,这是效果很好的代码:

# make sure to define the State variable before the Saver variable:
savedState = tf.get_variable('savedState', shape=[BATCHSIZE, CELL_SIZE * LAYERS])
saver = tf.train.Saver(max_to_keep=1)
# last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now save the State and the whole model:
assignOp = tf.assign(savedState, ostate)
sess.run(assignOp)
save_path = saver.save(sess, pathModel + '/my_model.ckpt')

# later on, in some other program, recover the model and the State:
# make sure to initialize all variables BEFORE recovering the model!
gInit = tf.global_variables_initializer().run()
local_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
local_saver.restore(sess, pathModel + 'my_model.ckpt')
# recover the state from training and get its last dimension
fullState = sess.run('savedState:0')
h = fullState[-1]
h = np.reshape(h, [1, -1])

我还没有测试过这种方法是否无意中初始化了保存的会话中的任何其他变量,但是不明白为什么应该运行特定的变量。

问题是在构造Saver之后创建新的tf.Variable意味着Saver不知道新变量。它仍然可以保存在Metagraph中,但没有保存在检查点:

import tensorflow as tf
with tf.Graph().as_default():
  var_a = tf.get_variable("a", shape=[])
  saver = tf.train.Saver()
  var_b = tf.get_variable("b", shape=[])
  print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
  initializer = tf.global_variables_initializer()
  with tf.Session() as session:
    session.run([initializer])
    saver.save(session, "/tmp/model", global_step=0)
with tf.Graph().as_default():
  new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
  print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
  with tf.Session() as session:
    new_saver.restore(session, "/tmp/model-0") # Only var_a gets restored!

我已经用Saver知道的变量来注释您上面的问题的快速复制。

现在,解决方案相对容易。我建议在Saver之前创建Variable,然后使用TF.Assign更新其值(请确保您 run tf.assign返回的OP(。分配的值将保存在检查点中,并像其他变量一样恢复。

None传递给其var_list构造函数参数时,Saver可以更好地处理这一点(即它可以自动拾取新变量(。请随时在GitHub上打开功能请求。

最新更新