从Keras中的h5文件加载自定义CTC层



我有一个CTCLayer类,如下所示:

class CTCLayer(layers.Layer):
def __init__(self, name=None):
super().__init__(name=name)
self.loss_fn = keras.backend.ctc_batch_cost

def call(self, y_true, y_pred):
# Compute the training-time loss value and add it
# to the layer using `self.add_loss()`.
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
loss = self.loss_fn(y_true, y_pred, input_length, label_length)
self.add_loss(loss)
# At test time, just return the computed predictions
return y_pred

我训练了我的模型,将其保存到model.h5文件中,并通过加载

model_load = tf.keras.models.load_model('model.h5', custom_objects={'CTCLayer': CTCLayer})

它抛出了一个init((得到了一个意外的关键字参数"trainable">错误。

由于我不想再次训练我的模型(时间约束(,有没有什么变通方法可以在不必在CTCLayer类中添加get_config((的情况下加载模型?

如果没有,我应该如何修改类中的get_config((?

这应该有效:

class CTCLayer(layers.Layer):
def __init__(self, name=None):
def __init__(self, name=None, **kwargs):
super(CTCLayer, self).__init__(name=name, **kwargs)

最新更新