我想在 seq2seq_model.py 中,把 embedding_attention_seq2seq 的编码器改成双向 RNN。下面是原始代码:
# Original (unidirectional) encoder from TF's legacy seq2seq library.
# NOTE(review): encoder_inputs / decoder_inputs are Python *lists* of 2-D
# [batch] integer tensors, one per time step (this is why the static,
# list-based rnn.rnn is used below) — confirm against the caller.
def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell,
num_encoder_symbols, num_decoder_symbols,
num_heads=1, output_projection=None,
feed_previous=False, dtype=dtypes.float32,
scope=None, initial_state_attention=False):
with variable_scope.variable_scope(scope or"embedding_attention_seq2seq"):
# Encoder: EmbeddingWrapper makes the cell embed the raw symbol ids itself,
# so encoder_inputs can stay integer-valued.
encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
# Static RNN: takes the list of per-step tensors, returns a list of outputs.
encoder_outputs, encoder_state = rnn.rnn(
encoder_cell, encoder_inputs, dtype=dtype)
# First calculate a concatenation of encoder outputs to put attention on.
# Each step output becomes [batch, 1, output_size]; concatenating on axis 1
# yields [batch, time, output_size]. tf 0.12 signature: concat(dim, values).
top_states = [array_ops.reshape(e, [-1, 1, cell.output_size])
for e in encoder_outputs]
attention_states = array_ops.concat(1, top_states)
....
下面是我修改后的代码:
# Encoder (bidirectional).
# The forward and backward cells must be *separate* objects — passing the
# same wrapped cell as both cell_fw and cell_bw makes the two directions
# collide on the same variables inside the BiRNN variable scopes.
encoder_cell_fw = core_rnn_cell.EmbeddingWrapper(
    copy.deepcopy(cell),
    embedding_classes=num_encoder_symbols,
    embedding_size=embedding_size)
encoder_cell_bw = core_rnn_cell.EmbeddingWrapper(
    copy.deepcopy(cell),
    embedding_classes=num_encoder_symbols,
    embedding_size=embedding_size)
# encoder_inputs is a Python list of per-time-step 2-D [batch] tensors,
# so the *static* bidirectional RNN must be used. bidirectional_dynamic_rnn
# wants a single 3-D [batch, time, depth] tensor; feeding it this list is
# what raised "Dimension must be 1 but is 3 ... Transpose" in the traceback.
# (Also: a positional argument may not follow keyword arguments, so
# encoder_inputs is passed positionally here — the original was a
# SyntaxError.)
(encoder_outputs,
 encoder_fw_final_state,
 encoder_bw_final_state) = rnn.bidirectional_rnn(
     encoder_cell_fw,
     encoder_cell_bw,
     encoder_inputs,
     dtype=dtype)
# Merge the final states of the two directions along the feature axis.
# tf 0.12 signature is concat(concat_dim, values) — the axis comes FIRST,
# consistent with array_ops.concat(1, top_states) elsewhere in this file;
# concat((a, b), 1) is the tf>=1.0 order and fails on 0.12.
encoder_final_state_c = tf.concat(
    1, (encoder_fw_final_state.c, encoder_bw_final_state.c))
encoder_final_state_h = tf.concat(
    1, (encoder_fw_final_state.h, encoder_bw_final_state.h))
# assumes `cell` is an LSTM cell so both final states are LSTMStateTuples
# — TODO confirm; a GRU cell has no .c/.h and would need a plain concat.
encoder_state = LSTMStateTuple(
    c=encoder_final_state_c,
    h=encoder_final_state_h)
运行后得到如下错误堆栈:
Traceback (most recent call last):
File "translate.py", line 301, in <module>
tf.app.run()
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/platform/app.py", line 43, in run
sys.exit(main(sys.argv[:1] + flags_passthrough))
File "translate.py", line 297, in main
train()
File "translate.py", line 156, in train
model = create_model(sess, False)
File "translate.py", line 134, in create_model
dtype=dtype)
File "/home/tensorflow/Downloads/NMT-jp-ch-master/seq2seq_model.py", line 185, in __init__
softmax_loss_function=softmax_loss_function)
File "/home/tensorflow/Downloads/NMT-jp-ch-master/seq2seq.py", line 628, in model_with_buckets
decoder_inputs[:bucket[1]])
File "/home/tensorflow/Downloads/NMT-jp-ch-master/seq2seq_model.py", line 184, in <lambda>
lambda x, y: seq2seq_f(x, y, False),
File "/home/tensorflow/Downloads/NMT-jp-ch-master/seq2seq_model.py", line 148, in seq2seq_f
dtype=dtype)
File "/home/tensorflow/Downloads/NMT-jp-ch-master/seq2seq.py", line 432, in embedding_attention_seq2seq
inputs=encoder_inputs
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/ops/rnn.py", line 652, in bidirectional_dynamic_rnn
time_major=time_major, scope=fw_scope)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/ops/rnn.py", line 789, in dynamic_rnn
for input_ in flat_input)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/ops/rnn.py", line 789, in <genexpr>
for input_ in flat_input)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/ops/array_ops.py", line 1280, in transpose
ret = gen_array_ops.transpose(a, perm, name=name)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3656, in transpose
result = _op_def_lib.apply_op("Transpose", x=x, perm=perm, name=name)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/framework/op_def_library.py", line 759, in apply_op
op_def=op_def)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 2242, in create_op
set_shapes_for_outputs(ret)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 1617, in set_shapes_for_outputs
shapes = shape_func(op)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 1568, in call_with_requiring
return call_cpp_shape_fn(op, require_shape_fn=True)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
debug_python_shape_fn, require_shape_fn)
File "/home/tensorflow/anaconda3/envs/tf/lib/python3.4/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl
raise ValueError(err.message)
ValueError: Dimension must be 1 but is 3 for 'model_with_buckets/embedding_attention_seq2seq/BiRNN/FW/transpose' (op: 'Transpose') with input shapes: [?], [3].
我使用的是 Python 3.4 和 TensorFlow v0.12。
如何参照 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/rnn.py 创建正确的双向 RNN 编码器?
预先感谢您。
补充:top_states 的 reshape 也要相应修改为
top_states = [array_ops.reshape(e, [-1, 1, cell.output_size * 2])
              for e in encoder_outputs]
是的,因为静态双向 RNN 会把前向和后向的输出在最后一维拼接,所以 reshape 的最后一维应为 cell.output_size * 2。