Error when training a single-layer LSTM in TensorFlow



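So I've been trying to train a single-layer encoder-decoder network in TensorFlow, and it has been incredibly frustrating given how sparse the documentation is. Everything I know about TensorFlow comes from Stanford's CS231n.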

So here is a simple model:

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.layers)


def simple_model(X, Y, is_training):
    """
    A simple, single-layer encoder-decoder network that encodes X of shape
    (batch_size, window_len, n_comp+1), then decodes Y of shape
    (batch_size, pred_len+1, n_comp+1), where Y[:, 0, :] is simply
    [0, ..., 0, 1] for every example in the batch, so that it starts
    the decoding.
    """
    num_units = 128
    window_len = X.shape[1]
    n_comp = X.shape[2] - 1
    pred_len = Y.shape[1] - 1
    init = tf.contrib.layers.variance_scaling_initializer()
    encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    encoder_output, encoder_state = tf.nn.dynamic_rnn(
        encoder_cell, X, dtype=tf.float32)
    decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    decoder_output, _ = tf.nn.dynamic_rnn(decoder_cell,
                                          encoder_output,
                                          initial_state=encoder_state)
    # we expect the shape to be of the shape of Y
    print(decoder_output.shape)
    proj_layer = tf.layers.dense(decoder_output, n_comp)
    return proj_layer
```

Now I try to set up the training details:

```python
tf.reset_default_graph()
X = tf.placeholder(tf.float32, [None, 15, 74])
y = tf.placeholder(tf.float32, [None, 4, 74])
is_training = tf.placeholder(tf.bool)
y_out = simple_model(X, y, is_training)
mean_loss = 0.5 * tf.reduce_mean((y_out - y[:, 1:, :-1]) ** 2)
optimizer = tf.train.AdamOptimizer(learning_rate=5e-4)
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
    train_step = optimizer.minimize(mean_loss)
```
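For reference, here is a pure-Python sketch of the shape bookkeeping implied by the placeholders above (X: [None, 15, 74], y: [None, 4, 74]). It is not TensorFlow code, and expected_shapes and start_vector are hypothetical helpers; the only rules it encodes are that dynamic_rnn keeps the time axis of its input and that tf.layers.dense changes only the last axis. Under those assumptions it suggests that, as written, y_out has window_len = 15 time steps (the decoder is fed encoder_output), while the loss target y[:, 1:, :-1] has 3, so the subtraction in mean_loss would likely fail with a shape mismatch even after the variable error is resolved.

```python
# Pure-Python sketch (not TensorFlow) of the shapes used by simple_model and
# the loss. Assumption: shapes follow the placeholders X: [None, 15, 74] and
# y: [None, 4, 74]; None stands for the unknown batch size.


def expected_shapes(x_shape, y_shape, num_units=128):
    """Return the shapes that simple_model and mean_loss work with."""
    batch, window_len, x_feat = x_shape
    _, y_len, y_feat = y_shape
    n_comp = x_feat - 1    # n_comp = X.shape[2] - 1
    pred_len = y_len - 1   # pred_len = Y.shape[1] - 1
    return {
        "n_comp": n_comp,
        "pred_len": pred_len,
        # dynamic_rnn keeps the time axis of its input; the decoder is fed
        # encoder_output, so both outputs have window_len time steps.
        "encoder_output": (batch, window_len, num_units),
        "decoder_output": (batch, window_len, num_units),
        # tf.layers.dense only changes the last axis.
        "y_out": (batch, window_len, n_comp),
        # y[:, 1:, :-1] drops the start step and the last feature.
        "target": (batch, y_len - 1, y_feat - 1),
    }


def start_vector(n_comp):
    """Start-of-decoding vector Y[:, 0, :] from the docstring: [0, ..., 0, 1]."""
    return [0.0] * n_comp + [1.0]


shapes = expected_shapes((None, 15, 74), (None, 4, 74))
for name, shape in shapes.items():
    print(name, shape)
```

Comparing the y_out and target entries is a quick way to see what the print(decoder_output.shape) line in the model would report versus what the loss expects.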

OK, and now I get this stupid error:

```
ValueError: Variable rnn/basic_lstm_cell/kernel already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:
```

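I'm not sure I've understood this correctly, but you have two BasicLSTMCells in your graph. According to the documentation, you should probably use a MultiRNNCell, like this: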

```python
encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
rnn_layers = [encoder_cell, decoder_cell]
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
decoder_output, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                          inputs=X,
                                          dtype=tf.float32)
```
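To make the difference concrete, the toy below (pure Python, not TensorFlow) mimics what a MultiRNNCell does at each time step: layer 0's output becomes layer 1's input, and the state is a tuple with one entry per layer. toy_cell is a hypothetical scalar stand-in for an LSTM cell. Note that this stacks both cells over the same input X, i.e. a two-layer LSTM, rather than running a separate decoder pass. In TF 1.x, MultiRNNCell also creates each layer's variables in its own sub-scope (cell_0, cell_1, ...), which is why this version does not hit the name clash.

```python
# Toy illustration (pure Python, NOT TensorFlow) of a MultiRNNCell-style stack:
# at every time step the output of layer i is the input of layer i + 1, and the
# overall state is a tuple holding one state per layer.


def toy_cell(scale):
    """Hypothetical stand-in for an RNN cell: new_state = scale * state + input."""
    def step(inp, state):
        new_state = scale * state + inp
        return new_state, new_state  # (output, new state)
    return step


def multi_cell_step(cells, inp, states):
    """One time step through the stack of cells."""
    new_states = []
    for cell, state in zip(cells, states):
        inp, new_state = cell(inp, state)  # this layer's output feeds the next
        new_states.append(new_state)
    return inp, tuple(new_states)


def run_stack(cells, inputs):
    """Unroll the stack over a sequence, as dynamic_rnn does over a MultiRNNCell."""
    states = tuple(0.0 for _ in cells)
    outputs = []
    for x in inputs:
        out, states = multi_cell_step(cells, x, states)
        outputs.append(out)
    return outputs, states


outputs, final_states = run_stack([toy_cell(0.5), toy_cell(0.5)], [1.0, 1.0, 1.0])
print(outputs)       # [1.0, 2.0, 2.75]
print(final_states)  # (1.75, 2.75)
```

Only the top layer's outputs are returned per step, while the final state keeps one entry per layer, mirroring the (outputs, state) pair returned by dynamic_rnn.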

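If this isn't the architecture you want and you really need two separate BasicLSTMCells, I think passing a different, unique name when defining encoder_cell and decoder_cell should resolve this error. tf.nn.dynamic_rnn places the cell under the "rnn" variable scope, so if the cell names are not set explicitly, both cells try to create the same variables (for example rnn/basic_lstm_cell/kernel), which causes the reuse conflict.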
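The naming problem can be illustrated with a small pure-Python model of TF 1.x variable creation. ToyVariableStore and build_rnn are hypothetical and are not TensorFlow's implementation: they only reproduce the rule that a variable's full name is scope, cell name and variable name joined by "/", and that creating a name that already exists without reuse is an error. The names are chosen to mirror the error message in the question. In real TF 1.x code, tf.nn.dynamic_rnn also accepts a scope argument, and wrapping each call in its own tf.variable_scope has a similar effect to giving the cells distinct names.

```python
# Toy model of TF 1.x-style variable naming (pure Python, NOT TensorFlow).
# Assumption: full name = scope/cell_name/var_name, and creating a name that
# already exists (without reuse) raises, like the ValueError in the question.


class ToyVariableStore:
    """Hypothetical stand-in for TensorFlow's variable store."""

    def __init__(self):
        self.variables = {}

    def get_variable(self, scope, cell_name, var_name):
        full_name = "/".join([scope, cell_name, var_name])
        if full_name in self.variables:
            raise ValueError(f"Variable {full_name} already exists, disallowed.")
        self.variables[full_name] = 0.0  # placeholder value
        return full_name


def build_rnn(store, cell_name="basic_lstm_cell", scope="rnn"):
    """Hypothetical stand-in for dynamic_rnn: creates a cell's kernel and bias.

    The defaults mimic an unnamed BasicLSTMCell under dynamic_rnn's default scope.
    """
    return [store.get_variable(scope, cell_name, v) for v in ("kernel", "bias")]


# Case 1: two unnamed cells under the default "rnn" scope -> name clash.
clash_store = ToyVariableStore()
build_rnn(clash_store)  # encoder
try:
    build_rnn(clash_store)  # decoder
    clash_message = None
except ValueError as err:
    clash_message = str(err)
print(clash_message)

# Case 2: distinct cell names -> distinct variables.
named_store = ToyVariableStore()
encoder_vars = build_rnn(named_store, cell_name="encoder_cell")
decoder_vars = build_rnn(named_store, cell_name="decoder_cell")
print(encoder_vars, decoder_vars)

# Case 3: distinct scopes (roughly like passing scope= to dynamic_rnn) -> no clash.
scoped_store = ToyVariableStore()
build_rnn(scoped_store, scope="encoder")
build_rnn(scoped_store, scope="decoder")
print(sorted(scoped_store.variables))
```

Case 1 reproduces the shape of the original message, while cases 2 and 3 show why either unique cell names or separate scopes keep the encoder and decoder weights apart.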
