Tensorflow addons seq2seq output of BasicDecoder call (tfa.s

I am building a seq2seq on top of tfa.seq2seq, working essentially as in https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt#train_the_model. I am looking at the nature of the output when calling BasicDecoder. I create a decoder instance

decoder_instance = tfa.seq2seq.BasicDecoder(cell=decoder.rnn_cell,
                                            sampler=greedy_sampler,
                                            output_layer=decoder.fc)

and later call it

outputs, _, _ = decoder_instance(decoder_embedding_matrix,
                                 start_tokens=start_tokens,
                                 end_token=end_token,
                                 initial_state=decoder_initial_state)

What is outputs here: the prediction probabilities?

Next I want to do something like this

predicted_logits = outputs.rnn_output          # logits, shape (batch, time, vocab)
predicted_logits = predicted_logits[:, -1, :]  # keep only the last step
predicted_logits = predicted_logits / temperature

# Sample the output logits to generate token IDs.
predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
predicted_ids = tf.squeeze(predicted_ids, axis=-1)

# Convert from token ids to characters
predicted_chars = chars_from_ids(predicted_ids)
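To check my understanding of the temperature step, here is a minimal NumPy sketch of what tf.random.categorical(logits / temperature, num_samples=1) does; the logits values and the seeded generator are made up for illustration:

```python
import numpy as np

def sample_with_temperature(logits, temperature=1.0, rng=None):
    """Draw one token id per batch row from unnormalized logits."""
    rng = rng or np.random.default_rng(0)
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max(axis=-1, keepdims=True)   # for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)     # softmax over the vocab axis
    # one categorical draw per batch row
    return np.array([rng.choice(len(p), p=p) for p in probs])

logits = np.array([[0.1, 2.0, 0.3]])   # one batch row, vocabulary of 3
ids = sample_with_temperature(logits, temperature=0.1)
```

A low temperature sharpens the distribution, so the draw almost always lands on the argmax (index 1 here); a high temperature flattens it toward uniform.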

EDIT

In my test, outputs looks like this:

BasicDecoderOutput(rnn_output=<tf.Tensor: shape=(1, 1, 106), dtype=float32, numpy=
array([[[-1.7647576 ,  1.2142688 ,  2.3475904 ,  0.35890207,
          0.72230023, -0.3587367 , -0.02984604, -1.9962349 ,
          0.510706  , -1.4457364 , -0.43458703, -0.55248725,
         -0.9126631 , -0.5542034 , -1.2392808 , -1.0972862 ,
         -0.7256295 ,  0.02101   , -1.0858598 ,  0.9452345 ,
          0.56474745,  0.2157154 ,  1.6094822 ,  0.6396736 ,
          1.5741622 ,  1.4455014 ,  0.9529134 ,  0.37970737,
         -0.60284877,  0.73455685,  1.0571934 ,  1.3716137 ,
         -1.0882497 ,  1.7738185 ,  1.1919689 ,  0.8144775 ,
          0.84732264,  1.6677057 ,  1.8040668 ,  0.86257285,
          2.0206916 ,  1.3602887 ,  1.2091455 ,  1.318665  ,
         -0.6775206 , -0.9906771 , -0.39923188, -1.0290842 ,
         -1.3546644 , -1.5678416 ,  0.624691  , -1.0316744 ,
          1.2098004 ,  1.4669724 ,  0.9996722 ,  0.12806134,
         -0.42086226, -0.11248919, -0.8277442 ,  0.622267  ,
         -1.6404072 ,  0.2762841 , -0.54035664, -0.6325757 ,
         -0.16794772,  0.8435169 ,  1.1214966 , -1.5629222 ,
          0.27472585,  0.8861834 , -1.7886144 ,  0.56741697,
         -1.9197755 , -1.8073375 , -1.5050163 , -1.7794812 ,
         -0.11308812,  1.3161705 ,  1.027235  ,  1.3830551 ,
         -1.374056  , -1.4779223 ,  0.19962706, -1.6843308 ,
          0.370475  ,  0.8292502 , -1.2990475 , -1.8491654 ,
         -3.4606798 , -0.9822829 , -2.391135  , -3.6944065 ,
         -3.5912528 , -2.4165688 , -2.640759  , -4.0524964 ,
         -3.0878603 , -1.6555822 , -1.2015637 , -1.7716323 ,
          1.7384199 , -2.4340994 , -0.7337967 , -0.88279086,
         -0.85630864, -0.8148002 ]]], dtype=float32)>, sample_id=<tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[2]], dtype=int32)>)
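The relationship between the two fields of that dump can be illustrated with a tiny NumPy stand-in; the namedtuple below is only a hypothetical mock of tfa.seq2seq.BasicDecoderOutput, with a made-up (batch=1, time=1, vocab=5) logits tensor:

```python
import numpy as np
from collections import namedtuple

# Hypothetical stand-in for tfa.seq2seq.BasicDecoderOutput, just to show
# how sample_id relates to rnn_output.
BasicDecoderOutput = namedtuple("BasicDecoderOutput", ["rnn_output", "sample_id"])

logits = np.array([[[-1.76, 1.21, 2.35, 0.36, 0.72]]], dtype=np.float32)
outputs = BasicDecoderOutput(rnn_output=logits,
                             sample_id=logits.argmax(axis=-1).astype(np.int32))

print(outputs.sample_id)  # [[2]] -- greedy argmax of the logits
```

The largest logit sits at index 2, so the greedy sample_id is [[2]], which is consistent with the sample_id in the dump above.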

For inference, class GreedyEmbeddingSampler(Sampler) is used: https://github.com/tensorflow/addons/blob/v0.15.0/tensorflow_addons/seq2seq/sampler.py#L559-L650

def sample(self, time, outputs, state):
    """sample for GreedyEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, use argmax to get the most probable id
    if not isinstance(outputs, tf.Tensor):
        raise TypeError(
            "Expected outputs to be a single Tensor, got: %s" % type(outputs)
        )
    sample_ids = tf.argmax(outputs, axis=-1, output_type=tf.int32)
    return sample_ids

So, as the comment says: "Outputs are logits, use argmax to get the most probable id".

BasicDecoder returns outputs = BasicDecoderOutput(cell_outputs, sample_ids): the raw outputs of the RNN cell (or of the final dense output_layer, if one is set) together with the ids obtained by taking the argmax of those logits.
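So rnn_output holds logits, not probabilities; to get a probability distribution per step, a softmax can be applied over the last (vocabulary) axis. A small sketch with made-up logit values:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax over the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# made-up (batch, time, vocab) logits, like a slice of rnn_output
logits = np.array([[[-1.76, 1.21, 2.35, 0.36]]])
probs = softmax(logits)
```

Each step now sums to 1, and the argmax of the probabilities is the same as the argmax of the logits, i.e. the greedy sample_id.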

LATEST UPDATE