我们如何将 .pth 模型转换为 .pb 文件?



我已经通过使用pytorch获得了完整的模型,但是我想将.pth文件转换为.pb,可以在Tensorflow中使用。有人有一些想法吗?

您可以使用 ONNX:开放式神经网络交换格式

要将.pth文件转换为.pb首先,您需要将 PyTorch 中定义的模型导出到 ONNX,然后将 ONNX 模型导入 Tensorflow (PyTorch => ONNX => Tensorflow(

这是 MNISTModel 使用 ONNX 将 PyTorch 模型转换为 Tensorflow 的示例 来自 onnx/教程

将训练好的模型保存到文件

torch.save(model.state_dict(), 'output/mnist.pth')

从文件加载训练的模型

trained_model = Net()
trained_model.load_state_dict(torch.load('output/mnist.pth'))
# Export the trained model to ONNX
dummy_input = Variable(torch.randn(1, 1, 28, 28)) # one black and white 28 x 28 picture will be the input to the model
torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")

加载 ONNX 文件

model = onnx.load('output/mnist.onnx')
# Import the ONNX model to Tensorflow
tf_rep = prepare(model)

将张量流模型保存到文件中

tf_rep.export_graph('output/mnist.pb')

正如@tsveti_iko在评论中指出的那样

注意:prepare()内置在onnx-tf中,因此您首先需要像pip install onnx-tf这样通过控制台安装它,然后将其导入到代码中,如下所示:import onnx from onnx_tf.backend import prepare之后,您终于可以按照答案中的说明使用它了。

如果您使用的是TF 1.15或更低版本,则可能不会发现上述代码有用,因为您最终会解决不匹配的版本错误
所以这是适用于TF 1.X的所有版本匹配代码

Keras                2.3.0
Keras-Applications   1.0.8
Keras-Preprocessing  1.1.2
numpy                1.21.5
onnx                 1.8.0
onnx-tf              1.3.0
protobuf             3.19.4
tensorboard          1.15.0
tensorflow           1.15.0
tensorflow-estimator 1.15.1
torch                1.6.0+cpu
torchvision          0.7.0+cpu

在拥有所有这些软件包后,请使用Dishin的答案

注意:Variable在较新版本的torch中已弃用

最新更新