如何将 Tensorflow Keras 自定义对象与tf.saved_model一起使用.资产?



我有一个自定义的 Keras 层,它从 pickle 文件中读取以初始化一些权重,我希望能够在其上使用tf.keras.utils.register_keras_serializable()。问题是我的__init__函数采用 pickle 文件的路径,当层再次反序列化时,该文件可能不可用。从理论上讲,Keras 资产应该使该层更具可移植性,但我无法弄清楚如何让它与该层的get_config()一起工作。

我的代码的准系统版本:

@tf.keras.utils.register_keras_serializable()
class AssetLayer(tf.keras.layers.Layer):
def __init__(self, asset_path, **kwargs):
super().__init__(**kwargs)
self.asset_path = asset_path
self.asset = tf.saved_model.Asset(asset_path)
data = tf.io.read_file(self.asset)
# do something with data
def get_config(self):
return {
**super().get_config(),
"asset_path": self.asset_path,
}
def call(self, arg):
# arbitrary call function
return arg

如果使用此层的模型使用tf.keras.models.load_model()加载,Keras 将调用get_config()以使用保存的asset_path重新初始化该层,该在反序列化时可能没有指向正确的位置。理想情况下,它会指向保存资产的路径,但我不知道如何让它做到这一点。

例如,我尝试过这段代码

!echo abcd > file.txt
model = tf.keras.Sequential([AssetLayer("file.txt")])
model(tf.ones(3))
model.save("test")
# reloading
!rm file.txt
reloaded_model = tf.keras.models.load_model("test")

这给了我一个错误,说找不到file.txt

我还尝试完全删除get_config()功能。这样就可以成功重新加载图层,同时保留对asset变量的访问权限,但无法访问图层中的其他属性(例如self.asset_path)。这对于调试目的并不理想,所以我想知道是否有更好的方法。

我目前正在使用Tensorflow 2.5.0'

编辑代码:在这部分之前,代码很好。由于以下原因正在复制

问题
!rm file.txt

(所以我把它放在最后)

!echo abcd > file.txt
model = tf.keras.Sequential([AssetLayer("file.txt")])
model(tf.ones(3))
model.save("./content/sample_data/test.h5")
# reloading
reloaded_model = tf.keras.models.load_model("/content/content/sample_data/test.h5")
reloaded_model.summary()
!rm file.txt

参考: https://www.tensorflow.org/guide/keras/save_and_serialize

似乎"tf.saved_model。资产"不支持"tf.keras.models.load_model" 尝试改用 tf.saved_model.save/tf.saved_model.load

相关内容

  • 没有找到相关文章

最新更新