Tensorflow seq2seq 'feed_previous' argument'



我正在使用内置的tf.nn.seq2seq.embedding_attention_seq2seq()函数,我对feed_previous参数有一些问题,在训练期间,地面事实被馈送到解码器中,而在测试期间,我们将最后一个时间步的输出馈送到解码器。问题是,一旦我设置了feed_previous参数,我就无法更改该参数。我想在每个时期测试我的模型,我该怎么办?

从文档中,您可以为feed_previous提供一个布尔张量。

feed_previous = tf.placeholder(tf.bool)
model = tf.nn.seq2seq.embedding_attention_seq2seq(..feed_previous=feed_previous...)
sess.run(loss, feed_dict={feed_previous=is_training, ...})

相关内容

最新更新