我使用TensorFlow编写了一个RNN语言模型。该模型被实现为一个RNN
类。图结构在构造函数中构建,RNN.train
和RNN.test
方法运行。
当我移动到训练集中的新文档时,或者当我想在训练期间运行验证集时,我希望能够重置RNN状态。我通过管理训练循环中的状态来做到这一点,通过提要字典将其传递到图中。
在构造函数中,我像这样定义RNN
cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
initial_state=self.state)
训练循环看起来像这样
for document in document:
state = session.run(self.reset_state)
for x, y in document:
_, state = session.run([self.train_step, self.next_state],
feed_dict={self.x:x, self.y:y, self.state:state})
x
和y
为文档中的批次训练数据。这个想法是,我在每个批处理之后传递最新的状态,除了当我开始一个新文档时,当我通过运行self.reset_state
将状态归零时。
这一切都有效。现在我想改变我的RNN使用推荐的state_is_tuple=True
。但是,我不知道如何通过提要字典传递更复杂的LSTM状态对象。我也不知道在构造函数中传递什么参数给self.state = tf.placeholder(...)
行。
正确的策略是什么?dynamic_rnn
仍然没有太多的示例代码或文档可用。
TensorFlow问题2695和2838似乎相关。
一篇关于WILDML的博客文章解决了这些问题,但没有直接给出答案。
参见TensorFlow:记住下一批LSTM的状态(有状态LSTM)
Tensorflow占位符的一个问题是,你只能用Python列表或Numpy数组(我认为)来提供它。所以你不能在LSTMStateTuple的元组中保存运行之间的状态。
我通过将状态保存在像这样的张量中来解决这个问题
initial_state = np.zeros((num_layers, 2, batch_size, state_size))
在构建图时,像这样解包并创建元组状态:
state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
[tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
for idx in range(num_layers)]
)
然后你就得到了新的状态
cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)
不应该是这样的…也许他们正在研究解决方案。
在RNN状态中提供的一种简单方法是简单地分别提供状态元组的两个组件。
# Constructing the graph
self.state = rnn_cell.zero_state(...)
self.output, self.next_state = tf.nn.dynamic_rnn(
rnn_cell,
self.input,
initial_state=self.state)
# Running with initial state
output, state = sess.run([self.output, self.next_state], feed_dict={
self.input: input
})
# Running with subsequent state:
output, state = sess.run([self.output, self.next_state], feed_dict={
self.input: input,
self.state[0]: state[0],
self.state[1]: state[1]
})