受监控的培训课程如何运作?



我试图了解使用tf.Sessiontf.train.MonitoredTrainingSession之间的区别,以及我可能更喜欢一个而不是另一个。似乎当我使用后者时,我可以避免许多"杂务",例如初始化变量、启动队列运行器或为摘要操作设置文件编写器。另一方面,对于受监控的培训课程,我无法指定要显式使用的计算图。这一切对我来说似乎相当神秘。这些类是如何创建的背后是否有一些我不理解的基本哲学?

我不能给出一些关于这些类是如何创建的见解,但这里有一些我认为与如何使用它们相关的内容。

tf.Session是python TensorFlow API中的一个低级对象,而, 正如您所说,tf.train.MonitoredTrainingSession具有许多方便的功能,在大多数常见情况下特别有用。

在描述tf.train.MonitoredTrainingSession的一些好处之前,让我回答有关会话使用的图形的问题。您可以使用上下文管理器with your_graph.as_default()指定MonitoredTrainingSession使用的tf.Graph

from __future__ import print_function
import tensorflow as tf
def example():
g1 = tf.Graph()
with g1.as_default():
# Define operations and tensors in `g`.
c1 = tf.constant(42)
assert c1.graph is g1
g2 = tf.Graph()
with g2.as_default():
# Define operations and tensors in `g`.
c2 = tf.constant(3.14)
assert c2.graph is g2
# MonitoredTrainingSession example
with g1.as_default():
with tf.train.MonitoredTrainingSession() as sess:
print(c1.eval(session=sess))
# Next line raises
# ValueError: Cannot use the given session to evaluate tensor:
# the tensor's graph is different from the session's graph.
try:
print(c2.eval(session=sess))
except ValueError as e:
print(e)
# Session example
with tf.Session(graph=g2) as sess:
print(c2.eval(session=sess))
# Next line raises
# ValueError: Cannot use the given session to evaluate tensor:
# the tensor's graph is different from the session's graph.
try:
print(c1.eval(session=sess))
except ValueError as e:
print(e)
if __name__ == '__main__':
example()

所以,正如你所说,使用MonitoredTrainingSession的好处是,这个对象可以照顾

  • 初始化变量,
  • 启动队列运行器以及
  • 设置文件编写器,

但它也有使代码易于分发的好处,因为它的工作方式也不同,具体取决于您是否将正在运行的进程指定为主进程。

例如,您可以运行如下内容:

def run_my_model(train_op, session_args):
with tf.train.MonitoredTrainingSession(**session_args) as sess:
sess.run(train_op)

您将以非分布式方式调用:

run_my_model(train_op, {})`

或以分布式方式(有关输入的更多信息,请参阅分布式文档):

run_my_model(train_op, {"master": server.target,
"is_chief": (FLAGS.task_index == 0)})

另一方面,使用原始tf.Session对象的好处是,您没有tf.train.MonitoredTrainingSession的额外好处,如果您不打算使用它们或想要获得更多控制(例如,关于队列的启动方式),这可能很有用。

编辑(根据评论):对于操作初始化,您必须执行以下操作(参见官方文档:

# Define your graph and your ops
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_p)
sess.run(your_graph_ops,...)

对于 QueueRunner,我会推荐您参考官方文档,您可以在其中找到更完整的示例。

编辑2:

要了解tf.train.MonitoredTrainingSession的工作原理,要了解的主要概念是_WrappedSession类:

此包装器用作各种会话包装器的基类 提供其他功能,例如监控、协调、 和恢复。

tf.train.MonitoredTrainingSession的工作方式(从版本 1.1 开始):

  • 它首先检查它是酋长还是工人(参见词汇问题的分布式文档)。
  • 它开始已经提供的钩子(例如,StopAtStepHook在此阶段只检索global_step张量。
  • 它创建一个会话,该会话是一个包装成_HookedSessionChief(或Worker会话),包装成包装在_RecoverableSession中的_CoordinatedSession
    Chief/Worker会话负责运行Scaffold提供的初始化操作。
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
    not specified a default one is created. It's used to finalize the graph.
    
  • chief会话还负责所有检查点部分:例如,使用ScaffoldSaver从检查点恢复。
  • _HookedSession基本上是为了装饰run方法:它调用_call_hook_before_run并在相关时after_run方法。
  • 在创建时,_CoordinatedSession会构建一个Coordinator,用于启动队列运行器并负责关闭它们。
  • _RecoverableSession将确保在tf.errors.AbortedError的情况下重试。

总之,tf.train.MonitoredTrainingSession避免了大量的样板代码,同时可以通过钩子机制轻松扩展。

最新更新