当我用一个或多个自定义层创建Keras模型时,我可以使用model.save()
方法使用TensorFlow SavedModel格式来持久化Keras模型。
我可以使用tf.keras.models.load_model()
函数从文件系统中加载该模型,并将其再次保存到文件系统中。
但是当我第二次从文件系统加载SavedModel时,它失败了,出现了以下异常:
TypeError: f(inputs, training, training, training, training, *, training, training) missing 1 required argument: training
您可以尝试使用以下代码复制此问题:
import tensorflow as tf
class CustomLayer(tf.keras.layers.Layer):
def call(self, inputs, *args, **kwargs):
return inputs
model1 = tf.keras.Sequential([
CustomLayer()
])
model1.build((None, 1))
model1.compile()
model1.save("model1")
model2 = tf.keras.models.load_model("model1")
model2.save("model2")
# This line should raise a TypeError.
model3 = tf.keras.models.load_model("model2")
问题存在的原因
问题是TensorFlow SavedModel格式实际上并不序列化自定义Python代码。它只保存由自定义Keras层和其他Python对象生成的TensorFlow图。
默认情况下,tf.keras.models.load_model()
函数不返回Python层。相反,它返回一个占位符层,其中包含TensorFlow计算图的相同部分。我们可以在我的问题中的例子中看到这一点:
>>> model1.layers
[<__main__.CustomLayer at 0x7ff04c14ee20>]
>>> model2.layers
[<keras.saving.saved_model.load.CustomLayer at 0x7ff114fd7be0>]
保存并加载model2
时,TensorFlow不能正确解析CustomLayer.call()
中的*args
和**kwargs
参数。
我不知道实际的bug是在保存代码中,还是在加载代码中,还是两者都有。
真正的修复需要在TensorFlow/Keras中发生,但与此同时,有
<标题>工作区h1> 可以选择以下任何一种解决方法来避免自定义Keras层的序列化错误。修改Layer.call()
上的签名
目前Layer.call()
上的官方方法签名是def call(self, inputs, *args, **kwargs):
但是当TensorFlow尝试加载带有此签名的自定义层的模型时,会抛出TypeError。为了修复这个错误,用def call(self, inputs):
的签名来编写所有的自定义层。如果你的层在训练或推理期间表现不同,那么你可以使用方法签名def call(self, inputs, training=None):
这使得TensorFlow更容易生成在keras.saving.saved_model.load
模块中生成的占位符层。但是这个占位符层仍然与原始Python代码不完全相同。
在tf.keras.models.load_model()
上使用custom_objects
参数可以用原始的Python层而不是占位符层加载模型。只需将一个映射层名的字典传递给Python层类对象。这要求您的代码能够导入原始的Python层。在我的问题的例子可以固定如下:
model3 = tf.keras.models.load_model(
"model2",
custom_objects=dict(
CustomLayer=CustomLayer,
),
)
确保你的图层实现了Layer.get_config()
,并返回一个字典,其中包含从头创建图层所需的所有参数。该图层必须能够用Layer.from_config()
重新创建。
导入Python层并将其添加到Keras的全局注册表
Keras维护一个自定义Python类和其他对象的全局注册表,以便在加载SavedModels时引用。你可以用@tf.keras.utils.register_keras_serializable()
装饰器注册你的自定义Keras层。例如:
@tf.keras.utils.register_keras_serializable(
package="my_python_package"
)
class CustomLayer(tf.keras.layers.Layer):
def call(self, inputs, *args, **kwargs):
return inputs
这个方法也要求你的层正确实现Layer.get_config()
。
使用tf.keras.utils.custom_object_scope()
安装Python层对象
与上面两个解决方案非常相似,tf.keras.utils.custom_object_scope()
上下文管理器可以指定反序列化时使用哪些自定义层。