如何实现对keras的字束搜索ctc



我正在构建一个手写识别模型,该模型目前具有88%的验证准确率。我看到了这个github页面,它可以帮助模型使用字典实现更准确的预测。

问题是我不知道如何在我当前的模型中实现这一点。这是我当前的ctc函数,它是从keras教程中复制的。如何修改此项以添加词典?

class CTCLayer(keras.layers.Layer):
def __init__(self, name=None):
super().__init__(name=name)
self.loss_fn = keras.backend.ctc_batch_cost
def call(self, y_true, y_pred):
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
loss = self.loss_fn(y_true, y_pred, input_length, label_length)
self.add_loss(loss)
# At test time, just return the computed predictions.
return y_pred

这是在github原始页面上实现的字束搜索。具体来说,我的主要问题是从函数中获得损失。稍后将损失返还给ctc层。


chars = ''.join(self.char_list)
word_chars = open('../model/wordCharList.txt').read().splitlines()[0]
corpus = open('../data/corpus.txt').read()
# decode using the "Words" mode of word beam search
from word_beam_search import WordBeamSearch
self.decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'),word_chars.encode('utf8'))

我试着查看在他们的项目中实现这一点的github页面,但他们似乎使用tensorflow v1,这对我来说有点困惑,因为我是这个领域的初学者。如有任何回应,不胜感激。

字波束搜索只是一个解码器,不是一个损失函数。对于损失,您仍然使用";标准";与Keras一起运送的CTC损失。这意味着在你的训练代码中,你甚至不必考虑单词波束搜索。

字束搜索仅用于推理。您所要做的就是将张量转换为numpy数组,有关详细信息,请参阅文档。

相关内容

  • 没有找到相关文章

最新更新