我的理解是,tf.nn.dynamic_rnn
在每个时间步返回RNN单元(例如LSTM(的输出以及最终状态。如何访问所有时间步长的单元格状态,而不仅仅是最后一个时间步长?例如,我希望能够平均所有隐藏状态,然后在后续层中使用它。
以下是我如何定义 LSTM 单元格,然后使用 tf.nn.dynamic_rnn
展开它。但这只给出了 LSTM 的最后一个单元状态。
import tensorflow as tf
import numpy as np
# [batch-size, sequence-length, dimensions]
X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 6]
cell = tf.contrib.rnn.LSTMCell(num_units=64, state_is_tuple=True)
outputs, last_state = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
out, last = sess.run([outputs, last_state], feed_dict=None)
这样的事情应该有效。
import tensorflow as tf
import numpy as np
class CustomRNN(tf.contrib.rnn.LSTMCell):
def __init__(self, *args, **kwargs):
kwargs['state_is_tuple'] = False # force the use of a concatenated state.
returns = super(CustomRNN, self).__init__(*args, **kwargs) # create an lstm cell
self._output_size = self._state_size # change the output size to the state size
return returns
def __call__(self, inputs, state):
output, next_state = super(CustomRNN, self).__call__(inputs, state)
return next_state, next_state # return two copies of the state, instead of the output and the state
X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 10]
cell = CustomRNN(num_units=64)
outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
states, last_state = sess.run([outputs, last_states], feed_dict=None)
这使用串联状态,因为我不知道您是否可以存储任意数量的元组状态。状态变量的形状为 (batch_size、max_time_size、state_size(。
我会指出这个线程(我的亮点(:
您可以编写 LSTMCell 的变体,如果每个时间步长都需要 c 和 h 状态,则返回两个状态张量作为输出的一部分。如果你只需要 h 状态,那就是每个时间步长的输出。
正如@jasekp在其评论中所写,输出实际上是状态h
部分。然后,dynamic_rnn
方法将跨时间堆叠所有h
部分(请参阅此文件中_dynamic_rnn_loop
的字符串文档(:
def _dynamic_rnn_loop(cell,
inputs,
initial_state,
parallel_iterations,
swap_memory,
sequence_length=None,
dtype=None):
"""Internal implementation of Dynamic RNN.
[...]
Returns:
Tuple `(final_outputs, final_state)`.
final_outputs:
A `Tensor` of shape `[time, batch_size, cell.output_size]`. If
`cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
objects, then this returns a (possibly nsted) tuple of Tensors matching
the corresponding shapes.