在尝试合并所有摘要(summary)时,我遇到一个错误,提示 `Merge/MergeSummary` 的输入来自不同的帧(frame)。那么首先:什么是帧(frame)?你能给我指出 TF 文档中介绍这个概念的地方吗?——我当然用谷歌搜索过,但几乎什么都没找到。该如何解决这个问题?下面是重现该错误的代码。提前致谢。
import numpy as np
import tensorflow as tf
# Dimensions used throughout the reproduction below.
BATCH = 2        # first dimension of the `states` and per-query tensors
LENGTH = 4       # sequence length of the fed states (the placeholder leaves it unknown)
SIZE = 5         # last dimension of each state vector
ATT_SIZE = 3     # attention projection size (also the query width here)
NUM_QUERIES = 2  # leading dimension of `queries`, i.e. how many queries are mapped over

# Start from an empty default graph, then fix its graph-level random seed so
# variable initialization is repeatable. The reset must come first: the seed
# is attached to the current default graph.
tf.reset_default_graph()
tf.set_random_seed(23)
def linear(inputs, output_size, use_bias=True, activation_fn=None):
    """Linear projection of the last axis of `inputs` to `output_size` units.

    Looks up a variable "kernel" of shape [input_size, output_size] and, when
    `use_bias` is true, a variable 'bias' of shape [output_size], both through
    tf.get_variable in the current variable scope (so they are created or
    reused depending on that scope's reuse setting).

    Args:
      inputs: tensor with a statically known rank; its last dimension is used
        as the kernel's input size (assumed static -- TODO confirm callers).
      output_size: size of the projected last axis.
      use_bias: whether to add the 'bias' variable.
      activation_fn: optional callable applied to the projected tensor.

    Returns:
      A tensor shaped like `inputs` except the last axis is `output_size`,
      passed through `activation_fn` when one is given.
    """
    static_shape = inputs.get_shape().as_list()
    in_dim = static_shape[-1]
    target_static_shape = static_shape[:-1] + [output_size]

    # tf.matmul needs rank-2 operands: flatten higher-rank inputs to
    # [-1, in_dim] and remember their run-time shape to restore afterwards.
    flatten = len(target_static_shape) > 2
    if flatten:
        dims = tf.unstack(tf.shape(inputs))
        dims[-1] = output_size
        restore_shape = tf.stack(dims)
        inputs = tf.reshape(inputs, [-1, in_dim])

    result = tf.matmul(inputs, tf.get_variable("kernel", [in_dim, output_size]))
    if use_bias:
        result = result + tf.get_variable('bias', [output_size])

    if flatten:
        result = tf.reshape(result, restore_shape)
        # Reshaping by a dynamic shape tensor loses static shape information;
        # put back what is known statically.
        result.set_shape(target_static_shape)  # pylint: disable=I0011,E1101

    return result if activation_fn is None else activation_fn(result)
class Attention(object):
    """Attention mechanism implementation.

    Scores each position of the attention states against a query as
    sum(Vector * tanh(state_features + query_features)) and normalizes the
    scores with a softmax over the positions.
    """

    def __init__(self, attention_states, attention_size):
        """Initializes a new instance of the Attention class.

        Args:
          attention_states: rank-3 tensor [batch, length, size]. `batch` and
            `length` are read dynamically via tf.shape; `size` must be known
            statically (it is read from get_shape()).
          attention_size: number of features the states and queries are
            projected to.
        """
        self._states = attention_states
        self._attention_size = attention_size
        self._batch = tf.shape(self._states)[0]
        self._length = tf.shape(self._states)[1]
        self._size = self._states.get_shape()[2].value
        # Projected state features, built lazily by the first get_weights()
        # call so the projection kernel is created in that call's scope.
        self._features = None

    def _init_features(self):
        """Projects every state vector to `attention_size` features.

        Uses a 1x1 convolution over the states reshaped to
        [batch, length, 1, size], giving features of shape
        [batch, length, 1, attention_size].
        """
        states = tf.reshape(
            self._states, [self._batch, self._length, 1, self._size])
        weights = tf.get_variable(
            "kernel", [1, 1, self._size, self._attention_size])
        self._features = tf.nn.conv2d(states, weights, [1, 1, 1, 1], "SAME")

    def get_weights(self, query, scope=None, summarize=True):
        """Returns the attention weights for the given query.

        Args:
          query: query tensor; presumably [batch, query_size] (one query per
            batch entry), since its projection is reshaped to
            [-1, 1, 1, attention_size] and broadcast against the features.
          scope: variable scope name; defaults to "Attention". Later calls are
            assumed to pass the same scope as the first call, because they
            switch the scope to variable reuse.
          summarize: if True (the default, matching the original behavior),
            record a histogram summary of the pre-softmax scores. Pass False
            when calling from inside tf.map_fn / tf.while_loop: a summary
            created inside a while loop cannot be consumed by
            tf.summary.merge_all().

        Returns:
          Softmax-normalized scores of shape [batch, length].
        """
        with tf.variable_scope(scope or "Attention"):
            if self._features is None:
                # First call: build (and cache) the projected state features.
                # NOTE(review): the cached features belong to whatever
                # control-flow context this first call runs in (e.g. the
                # while loop of a tf.map_fn); using them from a different
                # context later may fail -- confirm before mixing contexts.
                self._init_features()
            else:
                # Later calls share the variables created by the first call.
                tf.get_variable_scope().reuse_variables()
            vect = tf.get_variable("Vector", [self._attention_size])
            with tf.variable_scope("Query"):
                query_features = linear(query, self._attention_size, False)
            query_features = tf.reshape(
                query_features, [-1, 1, 1, self._attention_size])
            activations = vect * tf.tanh(self._features + query_features)
            # Sum over the singleton axis and the feature axis -> [batch, length].
            activations = tf.reduce_sum(activations, [2, 3])
            if summarize:
                with tf.name_scope('summaries'):
                    tf.summary.histogram('histogram', activations)
            return tf.nn.softmax(activations)
states = tf.placeholder(tf.float32, shape=[BATCH, None, SIZE])  # unknown length
queries = tf.placeholder(tf.float32, shape=[NUM_QUERIES, BATCH, ATT_SIZE])
attention = Attention(states, ATT_SIZE)
func = lambda x: attention.get_weights(x, "Softmax")
# tf.map_fn applies `func` to each slice of `queries` along axis 0 inside a
# while loop, so the histogram summary created by Attention.get_weights ends
# up inside that loop.
weights = tf.map_fn(func, queries)
# One histogram per trainable variable; ':' (as in "kernel:0") is replaced,
# presumably because it is not valid in a summary name.
for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    name = var.name.replace(':', '_')
    tf.summary.histogram(name, var)
# NOTE: this reproduces the problem being asked about. merge_all() also picks
# up the histogram created inside the map_fn while loop, which TensorFlow
# rejects ("... is in a while loop"; older versions reported the inputs of
# Merge/MergeSummary as coming from different frames).
summary_op = tf.summary.merge_all()
states_np = np.random.rand(BATCH, LENGTH, SIZE)
queries_np = np.random.rand(NUM_QUERIES, BATCH, ATT_SIZE)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    weights_np, summary_str = sess.run([weights, summary_op], {states: states_np, queries: queries_np})
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(weights_np)
这个错误消息确实不够友好。在新版本中,它已改为:
ValueError: Cannot use 'map/while/summaries/histogram' as input to 'Merge/MergeSummary' because 'map/while/summaries/histogram' is in a while loop. See info log for more details.
正如新消息所说,问题在于不能在 while 循环内部生成摘要。原始消息中提到的 frame 指的是 while 循环的"执行帧"(execution frame)——while 循环每次迭代的全部状态都保存在一个 frame 中。
在本例中,这个 while 循环是由 `tf.map_fn` 创建的,而摘要 `tf.summary.histogram('histogram', activations)` 就位于该循环内部。
有几种方法可以解决这个问题。一种是把摘要移出 `get_weights`,同时让 `get_weights` 也返回激活值,然后在循环外,用 `tf.map_fn` 调用新返回的激活值来创建摘要。(此时 `fn` 的输出结构与输入不同,需要为 `tf.map_fn` 显式指定 `dtype`,例如 `dtype=(tf.float32, tf.float32)`。)
另一种方法是:如果 `NUM_QUERIES` 是一个较小的常量,可以不用 `tf.map_fn`,而是静态展开循环。下面是相应的代码:
# TOP PART OF THE CODE IS THE SAME
states = tf.placeholder(tf.float32, shape=[BATCH, None, SIZE])  # unknown length
queries = tf.placeholder(tf.float32, shape=[NUM_QUERIES, BATCH, ATT_SIZE])
attention = Attention(states, ATT_SIZE)
func = lambda x: attention.get_weights(x, "Softmax")
# NEW CODE BEGIN
# Statically unroll the loop instead of using tf.map_fn: each call to `func`
# now creates ordinary graph ops (not ops inside a while loop), so the
# summaries created by get_weights can be merged by tf.summary.merge_all().
# tf.unstack yields NUM_QUERIES tensors of shape [BATCH, ATT_SIZE] -- the same
# per-element slices tf.map_fn passes to `func` (tf.split would keep an extra
# leading dimension of 1) -- and tf.stack reassembles the per-query results so
# `weights` has the same [NUM_QUERIES, ...] layout as tf.map_fn's output.
# The first call creates the attention variables; later calls reuse them.
weights = tf.stack([func(query) for query in tf.unstack(queries, num=NUM_QUERIES)])
# NEW CODE END
for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    name = var.name.replace(':', '_')
    tf.summary.histogram(name, var)
summary_op = tf.summary.merge_all()
states_np = np.random.rand(BATCH, LENGTH, SIZE)
queries_np = np.random.rand(NUM_QUERIES, BATCH, ATT_SIZE)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # NEW CODE BEGIN
    # `weights` is a single stacked tensor, so it is fetched exactly as in the
    # tf.map_fn version.
    weights_np, summary_str = sess.run([weights, summary_op], {states: states_np, queries: queries_np})
    # NEW CODE END
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(weights_np)