My code is inspired by https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BasicDecoder.
During translation/generation, we instantiate a BasicDecoder:
decoder_instance = tfa.seq2seq.BasicDecoder(cell=decoder.rnn_cell,
    sampler=greedy_sampler, output_layer=decoder.fc)
and call this decoder instance with the following args:
outputs, _, _ = decoder_instance(decoder_embedding_matrix,
    start_tokens=start_tokens, end_token=end_token, initial_state=decoder_initial_state)
What should start_tokens and end_token be, and what do they represent? An example in the BaseDecoder signature gives:
Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:
>>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
>>> decoder = tfa.seq2seq.BasicDecoder(
... decoder_cell, sampler, output_layer, maximum_iterations=10)
>>>
>>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
>>> start_tokens = tf.fill([batch_size], 1)
>>> end_token = 2
>>>
>>> output, state, lengths = decoder(
... None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
>>>
>>> output.sample_id.shape
TensorShape([4, 10])
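As a language-agnostic illustration of what these two arguments do, here is a pure-Python sketch (a toy successor table instead of a real model, not the tfa.seq2seq implementation): start_tokens seeds one sequence per batch entry, and decoding of a sequence stops as soon as end_token is produced.

```python
# Toy greedy decoder illustrating the role of start_tokens / end_token.
# The "model" is a fixed successor table -- purely illustrative.
def greedy_decode(start_tokens, end_token, successor, max_iterations=10):
    sequences = []
    for start in start_tokens:          # one decode per batch entry
        seq = []
        token = start
        for _ in range(max_iterations):
            token = successor[token]    # stand-in for argmax over model logits
            if token == end_token:      # <end> reached: stop this sequence
                break
            seq.append(token)
        sequences.append(seq)
    return sequences

# ids: 1 = <start>, 2 = <end>, 3..5 = ordinary tokens
successor = {1: 3, 3: 4, 4: 5, 5: 2}
print(greedy_decode([1, 1], 2, successor))  # → [[3, 4, 5], [3, 4, 5]]
```

So start_tokens is a vector with one start id per batch element (hence tf.fill([batch_size], 1) above), while end_token is a single scalar id that terminates decoding.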
For a translation task, they are:
start_tokens = tf.fill([inference_batch_size], targ_lang.word_index['<start>'])
end_token = targ_lang.word_index['<end>']
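For context, targ_lang.word_index is typically built by a tokenizer fitted on target sentences that were wrapped in '&lt;start&gt;'/'&lt;end&gt;' markers. A minimal pure-Python sketch of how such a word_index arises (the corpus and helper are hypothetical; only the '&lt;start&gt;'/'&lt;end&gt;' convention matters):

```python
# Sketch of a word_index like targ_lang.word_index: each word gets an
# integer id in order of first appearance, with 0 reserved for padding.
def build_word_index(sentences):
    word_index = {}
    for sentence in sentences:
        for word in sentence.split():
            if word not in word_index:
                word_index[word] = len(word_index) + 1  # ids start at 1
    return word_index

corpus = ["<start> wie geht es dir <end>", "<start> gut danke <end>"]
word_index = build_word_index(corpus)
start_token_id = word_index['<start>']  # used to fill start_tokens
end_token_id = word_index['<end>']      # passed as end_token
print(start_token_id, end_token_id)
```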
In my application, the input chain of characters has the form:
next_char = tf.constant(['Glücklicherweise '])
input_chars = tf.strings.unicode_split(next_char, 'UTF-8')
input_ids = ids_from_chars(input_chars).to_tensor()
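The ids_from_chars step behaves like a lookup over a character vocabulary (in the TF tutorials it is a tf.keras.layers.StringLookup layer). A pure-Python sketch of the same mapping, with a hypothetical toy vocabulary:

```python
# Pure-Python sketch of the ids_from_chars step: split a string into
# Unicode characters and map each to an integer id. The vocabulary is
# hypothetical; in the real pipeline it comes from a StringLookup layer.
vocab = sorted(set("Glücklicherweise "))              # toy character vocabulary
char_to_id = {c: i + 1 for i, c in enumerate(vocab)}  # 0 reserved for unknown/padding

def ids_from_chars(text):
    return [char_to_id[c] for c in text]  # one id per Unicode character

input_ids = ids_from_chars("Glücklicherweise ")
print(len(input_ids))  # 17 characters, since 'ü' is one Unicode character
```

Note that len(input_word[0].numpy()) counts bytes (18 here, since 'ü' is two UTF-8 bytes), while the character count is 17; this is why unit="UTF8_CHAR" matters in tf.strings.substr.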
The model is trained to generate the next token. The generator should produce "lücklicherweise x", where x stands for the next character, chosen as the most probable one (or via a more elaborate search).
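The one-step shift between input and generated output described above can be sketched in plain Python (illustration only, no model involved):

```python
# For next-character prediction, the target is the input shifted left by
# one character: the model sees "Glücklicherweise" and should emit
# "lücklicherweise " followed by its prediction for the next character.
text = "Glücklicherweise "
input_chars = list(text[:-1])   # input without the final character
target_chars = list(text[1:])   # the same text shifted left by one
print(''.join(target_chars))
```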
I tried several approaches, which should be easy to follow.
Input:
### Method 1: As string
index = 0
next_char = tf.strings.substr(
input_word, index, len(input_word[0].numpy()) - index, unit="UTF8_CHAR", name=None
)
end_token = len(input_word[0].numpy())
print('next_char[0].numpy(): ' + str(next_char[0].numpy()))
def f1():
    global pointer
    print(input_word[pointer])
    pointer = tf.add(pointer, 1)

def f2():
    global my_index
    print('add 1')
    my_index = tf.add(my_index, 1)

r = tf.cond(tf.less_equal(my_index, pointer), f1, f2)
Output:
### Method 1: As string
input_word[0].numpy() length: tf.Tensor([b'Gl\xc3\xbccklicherweise '], shape=(1,), dtype=string)
input_word[0].numpy() length: 18
next_char[0].numpy(): b'Gl\xc3\xbccklicherweise '
next_char[0].numpy(): b'l\xc3\xbccklicherweise '
next_char[0].numpy(): b'\xc3\xbccklicherweise '
next_char[0].numpy(): b'cklicherweise '
next_char[0].numpy(): b'klicherweise '
next_char[0].numpy(): b'licherweise '
...
### Method 2: As alphabets
tf.Tensor([b'G'], shape=(1,), dtype=string)
tf.Tensor([b'l'], shape=(1,), dtype=string)
tf.Tensor([b'\xc3\xbc'], shape=(1,), dtype=string)
tf.Tensor([b'c'], shape=(1,), dtype=string)
tf.Tensor([b'k'], shape=(1,), dtype=string)
tf.Tensor([b'l'], shape=(1,), dtype=string)
tf.Tensor([b'i'], shape=(1,), dtype=string)
tf.Tensor([b'c'], shape=(1,), dtype=string)
tf.Tensor([b'h'], shape=(1,), dtype=string)
tf.Tensor([b'e'], shape=(1,), dtype=string)
tf.Tensor([b'r'], shape=(1,), dtype=string)
tf.Tensor([b'w'], shape=(1,), dtype=string)
tf.Tensor([b'e'], shape=(1,), dtype=string)