张量流保存/恢复批量规范



我在Tensorflow中训练了一个具有批处理范数的模型。我想保存模型并恢复它以供进一步使用。批次规范由

def batch_norm(input, phase):
return tf.layers.batch_normalization(input, training=phase)

阶段在训练期间True,在测试期间False

似乎只是简单地打电话

saver = tf.train.Saver()
saver.save(sess, savedir + "ckpt")

不会很好地工作,因为当我恢复模型时,它首先说恢复成功。它还说Attempting to use uninitialized value batch_normalization_585/beta如果我只运行图形中的一个节点。这是否与未正确保存模型或我错过的其他内容有关?

我也有"尝试使用未初始化的值batch_normalization_585/beta"错误。这是因为通过像这样用空括号声明保存程序:

saver = tf.train.Saver() 

保护程序将保存 tf.trainable_variables(( 中包含的变量,这些变量不包含批量归一化的移动平均线。要将此变量包含在保存的 ckpt 中,您需要执行以下操作:

saver = tf.train.Saver(tf.global_variables())

这节省了所有变量,因此非常消耗内存。或者,您必须识别具有移动平均值或方差的变量,并通过声明它们来保存它们,如下所示:

saver = tf.train.Saver(tf.trainable_variables() + list_of_extra_variables)

不确定这是否需要解释,但以防万一(以及其他潜在的观众(。

每当您在 TensorFlow 中创建操作时,都会向图形中添加一个新节点。图形中没有两个节点可以具有相同的名称。您可以定义您创建的任何节点的名称,但如果您不给出名称,TensorFlow 将以确定性的方式为您选择一个(即,不是随机的,而是始终使用相同的顺序(。如果添加两个数字,则可能是Add,但是如果您进行另一个添加,因为没有两个节点可以具有相同的名称,因此它可能类似于Add_2.在图形中创建节点后,其名称无法更改。许多函数依次创建多个子节点;例如,tf.layers.batch_normalization创建一些内部变量betagamma

按以下方式保存和恢复作品:

  1. 创建表示所需模型的图形。此图包含将由保护程序保存的变量。
  2. 你对该图进行初始化、训练或做任何你想做的事情,模型中的变量就会被分配一些值。
  3. 在保护程序上调用save,将变量的值保存到文件中。
  4. 现在,您可以在不同的图形中重新创建模型(它可以是完全不同的Python会话,也可以是与第一个图形共存的另一个图形(。必须以与第一个模型完全相同的方式创建模型。
  5. 在保护程序上调用restore以检索变量的值。

为了使其正常工作,第一个和第二个图中变量的名称必须完全相同

在您的示例中,TensorFlow 抱怨变量batch_normalization_585/beta。似乎您在同一张图中调用了近 600 次tf.layers.batch_normalization,因此您有很多beta变量。我怀疑你是否真的需要那么多,所以我想你只是在试验 API,最终得到了那么多副本。

这是应该工作的内容的草稿:

import tensorflow as tf
def make_model():
input = tf.placeholder(...)
phase = tf.placeholder(...)
input_norm = tf.layers.batch_normalization(input, training=phase))
# Do some operations with input_norm
output = ...
saver = tf.train.Saver()
return input, output, phase, saver
# We work with one graph first
g1 = tf.Graph()
with g1.as_default():
input, output, phase, saver = make_model()
with tf.Session() as sess:
# Do your training or whatever...
saver.save(sess, savedir + "ckpt")
# We work with a second different graph now
g2 = tf.Graph()
with g2.as_default():
input, output, phase, saver = make_model()
with tf.Session() as sess:
saver.restore(sess, savedir + "ckpt")
# Continue using your model...

同样,典型的情况不是并排使用两个图,而是使用一个图,然后稍后在另一个 Python 会话中重新创建它,但最终两者是相同的。重要的部分是,在这两种情况下,模型的创建方式相同(因此具有相同的节点名称(。

相关内容

  • 没有找到相关文章

最新更新