Node 'Merge/MergeSummary' has inputs from different frames: what does this mean?



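When I try to merge all my summaries, I get an error saying that the inputs of Merge/MergeSummary come from different frames. So, first of all: what is a frame? Can you point me to somewhere in the TF documentation that covers this? Of course I googled it, but found almost nothing. How can I fix this problem? The code below reproduces the error. Thanks in advance.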

import numpy as np
import tensorflow as tf
tf.reset_default_graph()
tf.set_random_seed(23)
BATCH = 2
LENGTH = 4
SIZE = 5
ATT_SIZE = 3
NUM_QUERIES = 2
def linear(inputs, output_size, use_bias=True, activation_fn=None):
    """Linear projection."""
    input_shape = inputs.get_shape().as_list()
    input_size = input_shape[-1]
    output_shape = input_shape[:-1] + [output_size]
    if len(output_shape) > 2:
        output_shape_tensor = tf.unstack(tf.shape(inputs))
        output_shape_tensor[-1] = output_size
        output_shape_tensor = tf.stack(output_shape_tensor)
        inputs = tf.reshape(inputs, [-1, input_size])
    kernel = tf.get_variable("kernel", [input_size, output_size])
    output = tf.matmul(inputs, kernel)
    if use_bias:
        output = output + tf.get_variable('bias', [output_size])
    if len(output_shape) > 2:
        output = tf.reshape(output, output_shape_tensor)
        output.set_shape(output_shape)  # pylint: disable=I0011,E1101
    if activation_fn is not None:
        return activation_fn(output)
    return output

class Attention(object):
    """Attention mechanism implementation."""
    def __init__(self, attention_states, attention_size):
        """Initializes a new instance of the Attention class."""
        self._states = attention_states
        self._attention_size = attention_size
        self._batch = tf.shape(self._states)[0]
        self._length = tf.shape(self._states)[1]
        self._size = self._states.get_shape()[2].value
        self._features = None
    def _init_features(self):
        states = tf.reshape(
            self._states, [self._batch, self._length, 1, self._size])
        weights = tf.get_variable(
            "kernel", [1, 1, self._size, self._attention_size])
        self._features = tf.nn.conv2d(states, weights, [1, 1, 1, 1], "SAME")
    def get_weights(self, query, scope=None):
        """Returns the attention weights for the given query."""
        with tf.variable_scope(scope or "Attention"):
            if self._features is None:
                self._init_features()
            else:
                tf.get_variable_scope().reuse_variables()
            vect = tf.get_variable("Vector", [self._attention_size])
            with tf.variable_scope("Query"):
                query_features = linear(query, self._attention_size, False)
                query_features = tf.reshape(
                    query_features, [-1, 1, 1, self._attention_size])
        activations = vect * tf.tanh(self._features + query_features)
        activations = tf.reduce_sum(activations, [2, 3])
        with tf.name_scope('summaries'):
            tf.summary.histogram('histogram', activations)
        return tf.nn.softmax(activations)
states = tf.placeholder(tf.float32, shape=[BATCH, None, SIZE])  # unknown length
queries = tf.placeholder(tf.float32, shape=[NUM_QUERIES, BATCH, ATT_SIZE])
attention = Attention(states, ATT_SIZE)
func = lambda x: attention.get_weights(x, "Softmax")
weights = tf.map_fn(func, queries)
for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    name = var.name.replace(':', '_')
    tf.summary.histogram(name, var)
summary_op = tf.summary.merge_all()
states_np = np.random.rand(BATCH, LENGTH, SIZE)
queries_np = np.random.rand(NUM_QUERIES, BATCH, ATT_SIZE)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    weights_np, summary_str = sess.run([weights, summary_op], {states: states_np, queries: queries_np})
    print(weights_np)

The error message is indeed not user-friendly. It has since been updated to:

ValueError: Cannot use 'map/while/summaries/histogram' as input to 'Merge/MergeSummary' because 'map/while/summaries/histogram' is in a while loop. See info log for more details.

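As the new message says, the problem is that you cannot create summaries from inside a while loop. The "frame" the original message refers to is the while loop's "execution frame": all the state for each iteration of the while loop is kept in a frame.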

In this case, the while loop is created by tf.map_fn, and the summary inside it is tf.summary.histogram('histogram', activations).

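There are a few ways to fix this. You can move the summary out of get_weights, or you can have get_weights also return the activations and create the summary from the activations returned by the tf.map_fn call.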
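A minimal sketch of the second option, assuming the TF 1.x graph API (imported here through tensorflow.compat.v1 so it also runs on TF 2.x). The function body_fn and the constant FEATURES are hypothetical stand-ins for get_weights and its inputs, not the code from the question. The point is only that the histogram is created after tf.map_fn returns, so it lives outside the while loop's frame:

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.13 or TF 2.x

tf.disable_v2_behavior()
tf.reset_default_graph()

NUM_QUERIES = 2
BATCH = 2
FEATURES = 3  # hypothetical size, for illustration only

queries = tf.placeholder(tf.float32, shape=[NUM_QUERIES, BATCH, FEATURES])

def body_fn(query):
    """Hypothetical stand-in for get_weights: no summary is created here."""
    activations = query * 2.0
    # Return the activations alongside the weights so the caller can summarize them.
    return tf.nn.softmax(activations), activations

# dtype describes the structure of body_fn's return value.
weights, activations = tf.map_fn(
    body_fn, queries, dtype=(tf.float32, tf.float32))

# The summary is built from the stacked map_fn output, outside the loop.
tf.summary.histogram('activations', activations)
summary_op = tf.summary.merge_all()

with tf.Session() as sess:
    weights_np, summary_str = sess.run(
        [weights, summary_op],
        {queries: np.random.rand(NUM_QUERIES, BATCH, FEATURES)})
```

The histogram now covers the activations of all queries at once rather than one histogram per iteration, which is usually what you want in TensorBoard anyway.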

Alternatively, if NUM_QUERIES is constant and small, you can statically unroll the loop instead of using tf.map_fn. Here is code that does this:

# TOP PART OF THE CODE IS THE SAME
states = tf.placeholder(tf.float32, shape=[BATCH, None, SIZE])  # unknown length
queries = tf.placeholder(tf.float32, shape=[NUM_QUERIES, BATCH, ATT_SIZE])
attention = Attention(states, ATT_SIZE)
func = lambda x: attention.get_weights(x, "Softmax")
# NEW CODE BEGIN
split_queries = tf.split(queries, NUM_QUERIES)
weights = []
for query in split_queries:
    weights.append(func(query))
# NEW CODE END
for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    name = var.name.replace(':', '_')
    tf.summary.histogram(name, var)
summary_op = tf.summary.merge_all()
states_np = np.random.rand(BATCH, LENGTH, SIZE)
queries_np = np.random.rand(NUM_QUERIES, BATCH, ATT_SIZE)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # NEW CODE BEGIN
    results = sess.run(weights + [summary_op], {states: states_np, queries: queries_np})
    weights_np, summary_str = results[:-1], results[-1]
    # NEW CODE END
    print(weights_np)
