TensorFlow attention_decoder with RNNCell (state_is_tuple=Tr



我想用一个attention_decoder构建一个seq2seq模型,并使用MultiRNNCell与LSTMCell作为编码器。因为TensorFlow代码表明"这种默认行为(state_is_tuple=False)将很快被弃用",所以我为编码器设置了state_is_tuple=True。

问题是,当我将编码器的状态传递给attention_decoder时,它会报告一个错误:

*** AttributeError: 'LSTMStateTuple' object has no attribute 'get_shape'

这个问题似乎与seq2seq.py中的attention()函数和rnn_cell.py中的_linear()函数有关,其中代码从编码器生成的initial_state中调用'LSTMStateTuple'对象的'get_shape()'函数。

虽然当我为编码器设置state_is_tuple=False时错误消失,但程序给出以下警告:

WARNING:tensorflow:<tensorflow.python.ops.rnn_cell.LSTMCell object at 0x11763dc50>: Using a concatenated state is slower and will soon be deprecated.  Use state_is_tuple=True.

我真的很感激,如果有人能给任何指令建立seq2seq与RNNCell (state_is_tuple=True)。

我也遇到了这个问题,lstm状态需要连接起来,否则_linear会抱怨。LSTMStateTuple的形状取决于你使用的电池的种类。对于LSTM单元,您可以像这样连接状态:

 query = tf.concat(1,[state[0], state[1]])

如果你使用的是MultiRNNCell,请先连接每层的状态:

 concat_layers = [tf.concat(1,[c,h]) for c,h in state]
 query = tf.concat(1, concat_layers)

相关内容

最新更新