在同一个 Tensorflow 会话中从 Saver 加载两个模型



我有两个网络:一个生成输出的Model和一个对输出进行分级的Adversary

两者都是单独训练的,但现在我需要在单个会话中组合它们的输出。

我试图实现这篇文章中提出的解决方案:同时运行多个预先训练的Tensorflow网络

我的代码

with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)
#...
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Get the variables specific to the `Model`
    # Also strip out the surperfluous ":0" for some reason not saved in the checkpoint
    model_varlist = {v.name.lstrip("model/")[:-2]: v 
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)
    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("avd/")[:-2]: v 
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)

问题所在

对函数model_saver.restore()调用似乎什么也没做。在另一个模块中,我使用带有tf.train.Saver(tf.global_variables())的保护程序,它可以很好地恢复检查点。

该模型具有model.tvars = tf.trainable_variables() .为了检查发生了什么,我使用sess.run()在还原前后提取tvars。每次使用初始随机分配的变量并且未分配检查点中的变量时。

关于为什么model_saver.restore()似乎什么都不做的任何想法?

解决这个问题花了很长时间,所以我发布了我可能不完美的解决方案,以防其他人需要它。

为了诊断问题,我手动遍历每个变量并逐个分配它们。然后我注意到分配变量后名称会更改。这是描述如下的:TensorFlow检查点保存和读取

根据那篇文章中的建议,我在自己的图表中运行了每个模型。这也意味着我必须在自己的会话中运行每个图形。这意味着以不同的方式处理会话管理。

首先,我创建了两个图表

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)
adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)

然后是两个会话

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

然后我初始化了每个会话中的变量并分别恢复了每个图形

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)
with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

从这里开始,每当需要每个会话时,我都会用with sess.as_default():包装该会话中的任何tf函数。最后我手动关闭会话

sess.close()
adv_sess.close()

标记为正确的答案并没有告诉我们如何将两个不同的模型显式加载到一个会话中,这是我的答案:

  1. 为要加载的模型创建两个不同的名称范围。

  2. 初始化两个保存程序,它们将加载两个不同网络中变量的参数。

  3. 从相应的检查点文件加载。

with tf.Session() as sess:
    with tf.name_scope("net1"):
      net1 = Net1()
    with tf.name_scope("net2"):
      net2 = Net2()
    net1_varlist = {v.op.name.lstrip("net1/"): v
                    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}
    net1_saver = tf.train.Saver(var_list=net1_varlist)
    net2_varlist = {v.op.name.lstrip("net2/"): v
                    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}
    net2_saver = tf.train.Saver(var_list=net2_varlist)
    net1_saver.restore(sess, "net1.ckpt")
    net2_saver.restore(sess, "net2.ckpt")

请检查这个:

adv_varlist = {v.name.lstrip("avd/")[:-2]: v 

它应该是"adv",不是吗

最新更新