永久将常数注入TensorFlow图,以进行推理



我用占位符的is_training训练模型:

is_training_ph = tf.placeholder(tf.bool)

但是,一旦完成训练和验证,我想永久将false的常数注入此值中,然后"重新优化"该图(即使用optimize_for_inference)。freeze_graph的线路是否会这样做?

一种可能性是使用tf.import_graph_def()函数及其input_map参数来重写图中该张量的值。例如,您可以按以下方式构建程序:

with tf.Graph().as_default() as training_graph:
  # Build model.
  is_training_ph = tf.placeholder(tf.bool, name="is_training")
  # ...
training_graph_def = training_graph.as_graph_def()
with tf.Graph().as_default() as temp_graph:
  tf.import_graph_def(training_graph_def,
                      input_map={is_training_ph.name: tf.constant(False)})
temp_graph_def = temp_graph.as_graph_def()

构建temp_graph_def后,您可以将其用作freeze_graph的输入。


一种替代方案,它可能与freeze_graphoptimize_for_inference脚本更兼容(对可变名称和检查点键的假设)是修改TensorFlow的graph_util.convert_variables_to_constants()函数,以使其转换为占位符:

def convert_placeholders_to_constants(input_graph_def,
                                      placeholder_to_value_map):
  """Replaces placeholders in the given tf.GraphDef with constant values.
  Args:
    input_graph_def: GraphDef object holding the network.
    placeholder_to_value_map: A map from the names of placeholder tensors in
      `input_graph_def` to constant values.
  Returns:
    GraphDef containing a simplified version of the original.
  """
  output_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    output_node = tf.NodeDef()
    if node.op == "Placeholder" and node.name in placeholder_to_value_map:
      output_node.op = "Const"
      output_node.name = node.name
      dtype = node.attr["dtype"].type
      data = np.asarray(placeholder_to_value_map[node.name],
                        dtype=tf.as_dtype(dtype).as_numpy_dtype)
      output_node.attr["dtype"].type = dtype
      output_node.attr["value"].CopyFrom(tf.AttrValue(
          tensor=tf.contrib.util.make_tensor_proto(data,
                                                   dtype=dtype,
                                                   shape=data.shape)))
    else:
      output_node.CopyFrom(node)
    output_graph_def.node.extend([output_node])
  return output_graph_def

...然后您可以如上所述构建training_graph_def,然后写:

temp_graph_def = convert_placeholders_to_constants(training_graph_def,
                                                   {is_training_ph.op.name: False})

最新更新