TensorFlow Keras SavedModel在保存和加载两次后抛出TypeError &g



当我用一个或多个自定义层创建Keras模型时,我可以使用model.save()方法使用TensorFlow SavedModel格式来持久化Keras模型。

我可以使用tf.keras.models.load_model()函数从文件系统中加载该模型,并将其再次保存到文件系统中。

但是当我第二次从文件系统加载SavedModel时,它失败了,出现了以下异常:

TypeError: f(inputs, training, training, training, training, *, training, training) missing 1 required argument: training

您可以尝试使用以下代码复制此问题:

import tensorflow as tf
class CustomLayer(tf.keras.layers.Layer):
def call(self, inputs, *args, **kwargs):
return inputs
model1 = tf.keras.Sequential([
CustomLayer()
])
model1.build((None, 1))
model1.compile()
model1.save("model1")
model2 = tf.keras.models.load_model("model1")
model2.save("model2")
# This line should raise a TypeError.
model3 = tf.keras.models.load_model("model2")

问题存在的原因

问题是TensorFlow SavedModel格式实际上并不序列化自定义Python代码。它只保存由自定义Keras层和其他Python对象生成的TensorFlow图。

默认情况下,tf.keras.models.load_model()函数不返回Python层。相反,它返回一个占位符层,其中包含TensorFlow计算图的相同部分。我们可以在我的问题中的例子中看到这一点:

>>> model1.layers
[<__main__.CustomLayer at 0x7ff04c14ee20>]
>>> model2.layers
[<keras.saving.saved_model.load.CustomLayer at 0x7ff114fd7be0>]

保存并加载model2时,TensorFlow不能正确解析CustomLayer.call()中的*args**kwargs参数。

我不知道实际的bug是在保存代码中,还是在加载代码中,还是两者都有。

真正的修复需要在TensorFlow/Keras中发生,但与此同时,有

<标题>工作区h1> 可以选择以下任何一种解决方法来避免自定义Keras层的序列化错误。

修改Layer.call()上的签名

目前Layer.call()上的官方方法签名是def call(self, inputs, *args, **kwargs):

但是当TensorFlow尝试加载带有此签名的自定义层的模型时,会抛出TypeError。为了修复这个错误,用def call(self, inputs):的签名来编写所有的自定义层。如果你的层在训练或推理期间表现不同,那么你可以使用方法签名def call(self, inputs, training=None):

这使得TensorFlow更容易生成在keras.saving.saved_model.load模块中生成的占位符层。但是这个占位符层仍然与原始Python代码不完全相同。

tf.keras.models.load_model()

上使用custom_objects

参数可以用原始的Python层而不是占位符层加载模型。只需将一个映射层名的字典传递给Python层类对象。这要求您的代码能够导入原始的Python层。在我的问题的例子可以固定如下:

model3 = tf.keras.models.load_model(
"model2",
custom_objects=dict(
CustomLayer=CustomLayer,
),
)

确保你的图层实现了Layer.get_config(),并返回一个字典,其中包含从头创建图层所需的所有参数。该图层必须能够用Layer.from_config()重新创建。

导入Python层并将其添加到Keras的全局注册表

Keras维护一个自定义Python类和其他对象的全局注册表,以便在加载SavedModels时引用。你可以用@tf.keras.utils.register_keras_serializable()装饰器注册你的自定义Keras层。例如:

@tf.keras.utils.register_keras_serializable(
package="my_python_package"
)
class CustomLayer(tf.keras.layers.Layer):
def call(self, inputs, *args, **kwargs):
return inputs

这个方法也要求你的层正确实现Layer.get_config()

使用tf.keras.utils.custom_object_scope()安装Python层对象

与上面两个解决方案非常相似,tf.keras.utils.custom_object_scope()上下文管理器可以指定反序列化时使用哪些自定义层。

相关内容

  • 没有找到相关文章

最新更新