摘要
我的问题由以下人员组成:
- 我展示我的项目、工作环境和工作流程的背景
- 详细问题
- 我的代码的相关部分
- 我试图解决我的问题的解决方案
- 问题提醒
上下文
我已经编写了一个原始超分辨率GAN降级版本的Python Keras实现。现在我想使用Google Firebase机器学习工具包进行测试,将其托管在Google服务器中。这就是为什么我必须将我的Keras程序转换为TensorFlow Lite程序。
环境和工作流程(有问题(
我正在Google Colab工作环境中训练我的程序:在那里,我已经安装了TF 2.0.0-beta1
(这个选择是由这个不正确的答案驱动的:https://datascience.stackexchange.com/a/57408/78409(。
工作流程(和问题(:
-
我在本地编写我的Python Keras程序,记住它将在TF 2上运行。所以我使用 TF 2 导入,例如:
from tensorflow.keras.optimizers import Adam
和from tensorflow.keras.layers import Conv2D, BatchNormalization
-
我将代码发送到云端硬盘
-
我运行我的谷歌Colab笔记本没有任何问题:TF 2被使用。
-
我在云端硬盘中获取输出模型,然后下载。
-
我尝试通过执行以下 CLI 将此模型转换为 TFLite 格式:
tflite_convert --output_file=srgan.tflite --keras_model_file=srgan.h5
:这里出现问题。
问题所在
以前的 CLI 不会从 TF (Keras( 模型输出 TF Lite 转换模型,而是输出此错误:
值错误:未知损失函数:build_vgg19_loss_network
函数build_vgg19_loss_network
是我已经实现的自定义损失函数,必须由 GAN 使用。
引发此问题的代码部分
呈现自定义损失函数
自定义损失函数是这样实现的:
def build_vgg19_loss_network(ground_truth_image, predicted_image):
loss_model = Vgg19Loss.define_loss_model(high_resolution_shape)
return mean(square(loss_model(ground_truth_image) - loss_model(predicted_image)))
使用我的自定义损失函数编译生成器网络
generator_model.compile(optimizer=the_optimizer, loss=build_vgg19_loss_network)
为了解决问题,我试图做什么
当我在 StackOverflow 上阅读它时(这个问题开头的链接(,TF 2 被认为足以输出一个 Keras 模型,该模型将由我的
tflite_convert
CLI 正确处理。但显然不是。当我在 GitHub 上阅读它时,我尝试通过添加以下行在 Keras 的损失函数中手动设置我的自定义损失函数:
import tensorflow.keras.losses tensorflow.keras.losses.build_vgg19_loss_network = build_vgg19_loss_network
.它没有用。我在GitHub上读到我可以将自定义对象与
load_model
Keras 函数一起使用:但我只想使用 Keras 函数compile
。不load_model
.
我的最后一个问题
我只想对我的代码进行微小的更改,因为它工作正常。所以我不想,例如,用load_model
替换compile
.有了这个约束,你能帮我,请,让我的 CLItflite_convert
与我的自定义损失函数一起使用吗?
由于您声称 TFLite 转换由于自定义损失函数而失败,因此您可以保存模型文件而不保留优化器详细信息。为此,请将include_optimizer
参数设置为 False,如下所示:
model.save('model.h5', include_optimizer=False)
现在,如果模型内的所有层都是可转换的,则它们应该转换为 TFLite 文件。
编辑: 然后,您可以像这样转换 h5 文件:
import tensorflow as tf
model = tf.keras.models.load_model('model.h5') # srgan.h5 for you
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
此处记录了克服 TFLite 转换中不支持的运算符的常规做法。
我遇到了同样的错误。我建议将损失更改为"mse",因为您已经有一个训练良好的模型,并且不需要使用 .tflite 文件进行训练。