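What is the shape of the contents of the outputs of tf.contrib.seq2seq.BeamSearchDecoder? I know it is an instance of class BeamSearchDecoderOutput(scores, predicted_ids, parent_ids), but what are the shapes of scores, predicted_ids and parent_ids?
I wrote the following toy code to explore it.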
import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

tgt_vocab_size = 20
# One-hot rows serve as a fixed toy embedding matrix.
embedding_decoder = tf.one_hot(list(range(0, tgt_vocab_size)), tgt_vocab_size)
batch_size = 2
start_tokens = tf.fill([batch_size], 0)
end_token = 1
beam_width = 3
num_units = 18

decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
# A zero state stands in for a real encoder's final state.
encoder_outputs = decoder_cell.zero_state(batch_size, dtype=tf.float32)
# Beam search expects the initial state tiled beam_width times per batch entry.
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)

my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=decoder_cell,
    embedding=embedding_decoder,
    start_tokens=start_tokens,
    end_token=end_token,
    initial_state=tiled_encoder_outputs,
    beam_width=beam_width)

# dynamic decoding
outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
    my_decoder,
    maximum_iterations=4,
    output_time_major=True)

final_predicted_ids = outputs.predicted_ids
scores = outputs.beam_search_decoder_output.scores
predicted_ids = outputs.beam_search_decoder_output.predicted_ids
parent_ids = outputs.beam_search_decoder_output.parent_ids

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Evaluate all four tensors in a single run so they come from the same decode.
    final_predicted_ids_vals, scores_vals, predicted_ids_vals, parent_ids_vals = sess.run(
        [final_predicted_ids, scores, predicted_ids, parent_ids])
    print("final_predicted_ids shape:")
    print(final_predicted_ids_vals.shape)
    print("final_predicted_ids_vals: \n%s" % final_predicted_ids_vals)
    print("scores shape:")
    print(scores_vals.shape)
    print("scores values: \n %s" % scores_vals)
    print("predicted_ids shape: ")
    print(predicted_ids_vals.shape)
    print("predicted_ids values: \n %s" % predicted_ids_vals)
    print("parent_ids shape:")
    print(parent_ids_vals.shape)
    print("parent_ids values: \n %s" % parent_ids_vals)
It prints the following:
final_predicted_ids shape:
(4, 2, 3)
final_predicted_ids_vals:
[[[ 1 8 8]
[ 1 8 8]]
[[ 1 13 13]
[ 1 13 13]]
[[ 1 13 13]
[ 1 13 13]]
[[ 1 13 2]
[ 1 13 2]]]
scores shape:
(4, 2, 3)
scores values:
[[[ -2.8376358 -2.843168 -2.8478816]
[ -2.8376358 -2.843168 -2.8478816]]
[[ -2.8478816 -5.655898 -5.6810265]
[ -2.8478816 -5.655898 -5.6810265]]
[[ -2.8478816 -8.478384 -8.495466 ]
[ -2.8478816 -8.478384 -8.495466 ]]
[[ -2.8478816 -11.292251 -11.307263 ]
[ -2.8478816 -11.292251 -11.307263 ]]]
predicted_ids shape:
(4, 2, 3)
predicted_ids values:
[[[ 8 13 1]
[ 8 13 1]]
[[ 1 13 13]
[ 1 13 13]]
[[ 1 13 12]
[ 1 13 12]]
[[ 1 13 2]
[ 1 13 2]]]
parent_ids shape:
(4, 2, 3)
parent_ids values:
[[[0 0 0]
[0 0 0]]
[[2 0 1]
[2 0 1]]
[[0 1 1]
[0 1 1]]
[[0 1 1]
[0 1 1]]]
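The outputs of tf.contrib.seq2seq.dynamic_decode(BeamSearchDecoder) is actually an instance of class FinalBeamSearchDecoderOutput, which consists of:
predicted_ids: the final outputs returned by the beam search after all decoding is finished. Its shape is [batch_size, num_steps, beam_width] (or [num_steps, batch_size, beam_width] if output_time_major is True). Beams are ordered from best to worst.
beam_search_decoder_output: an instance of BeamSearchDecoderOutput that describes the state of the beam search. Its scores, predicted_ids and parent_ids fields hold the per-step values and, as the printout in the question shows, have the same three-dimensional shape.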
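To illustrate how the per-step predicted_ids and parent_ids relate to the final predicted_ids, here is a minimal NumPy sketch that follows the parent pointers backwards from the last step. The function name backtrack_beams is hypothetical, not part of TensorFlow, and the sketch is a simplification: it assumes time-major inputs and ignores the end-token padding and sequence-length handling that the library's own gather step performs. The arrays are copied from the printout in the question.

```python
import numpy as np

def backtrack_beams(step_ids, parent_ids):
    """Rebuild full beams from per-step tokens and parent pointers.

    Both inputs are int arrays of shape [num_steps, batch_size, beam_width].
    Hypothetical helper for illustration; not a TensorFlow API.
    """
    num_steps, batch_size, beam_width = step_ids.shape
    result = np.zeros_like(step_ids)
    for b in range(batch_size):
        for k in range(beam_width):
            beam = k  # start from final beam k and walk back in time
            for t in range(num_steps - 1, -1, -1):
                result[t, b, k] = step_ids[t, b, beam]
                beam = parent_ids[t, b, beam]  # beam this hypothesis came from
    return result

# Values for one batch entry, taken from the question's output
# (both batch entries are identical there).
step = np.array([[8, 13, 1], [1, 13, 13], [1, 13, 12], [1, 13, 2]])
parent = np.array([[0, 0, 0], [2, 0, 1], [0, 1, 1], [0, 1, 1]])
final = np.array([[1, 8, 8], [1, 13, 13], [1, 13, 13], [1, 13, 2]])

# Add the batch axis: [num_steps, batch_size, beam_width] = [4, 2, 3].
step_ids = np.repeat(step[:, None, :], 2, axis=1)
parent_ids = np.repeat(parent[:, None, :], 2, axis=1)
final_ids = np.repeat(final[:, None, :], 2, axis=1)

rebuilt = backtrack_beams(step_ids, parent_ids)
print(np.array_equal(rebuilt, final_ids))
```

For the values printed in the question, the backtracked beams match final_predicted_ids, which is why the per-step predicted_ids differ from the final ones at some steps.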
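To get the final predictions/translations in shape [beam_width, batch_size, num_steps], apply tf.transpose(final_predicted_ids, [2, 0, 1]) to the batch-major output, or plain tf.transpose(final_predicted_ids), which reverses the axes, if output_time_major=True.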
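The two transposes can be checked on dummy data. The sketch below uses NumPy rather than TensorFlow, on the assumption that np.transpose follows the same axis conventions as tf.transpose (with no permutation given, both reverse the axes). The array contents are placeholders, not decoder output.

```python
import numpy as np

num_steps, batch_size, beam_width = 4, 2, 3

# Time-major layout, as with output_time_major=True:
# [num_steps, batch_size, beam_width].
time_major = np.arange(num_steps * batch_size * beam_width).reshape(
    num_steps, batch_size, beam_width)

# The same data in batch-major layout: [batch_size, num_steps, beam_width].
batch_major = np.transpose(time_major, (1, 0, 2))

# Time-major: reversing the axes gives [beam_width, batch_size, num_steps].
beams_from_time_major = np.transpose(time_major)

# Batch-major: the permutation (2, 0, 1) gives the same layout.
beams_from_batch_major = np.transpose(batch_major, (2, 0, 1))

print(beams_from_time_major.shape)
print(np.array_equal(beams_from_time_major, beams_from_batch_major))

# Beams are ordered best to worst, so index 0 selects the best beam:
# shape [batch_size, num_steps].
best_beam = beams_from_time_major[0]
```

With the beam axis first, indexing [0] picks the top-ranked hypothesis for every batch entry.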