是否可以将TensorFlow中的训练模型转换为可用于迁移学习的对象



我想使用这里描述的迁移学习:https://www.tensorflow.org/tutorials/images/transfer_learning

问题是,我试图用作基础模型的模型不是已知的内置Keras模型之一,如MobileNetV2。因此,我想我需要完成以下第一步(步骤1(,才能完成迁移学习教程中提到的内容(步骤2-6(
1.从包含Saved_model文件的目录中加载模型
2.冻结模型(使其可训练参数不可更改(
3。制作一个单独的层并将其堆叠在冻结模型的顶部
4。训练生成的模型
5.保存新训练的模型
6.使用新训练的模型进行预测。

我的问题是关于第一步。当我尝试使用以下Python代码/脚本加载模型时,我遇到了一个错误,我不知道如何修复它:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
tf.saved_model.load(
export_dir='/dir_to_the_model_files/', tags=None
)

错误为:

OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..

我还认为,可能有一种方法可以将TensorFlow文件(包括(saved_model.ckpt-0.data-0000-of-0001(转换为Keras API可读的文件(例如h5py.File格式(,这可能有助于类似于上述教程的迁移学习。因此,我可以将类似的方法应用于以下方法来提取基本模型并执行下一步操作。

base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')

或者最好使用以下方法https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model:

tf.keras.models.load_model(
filepath, custom_objects=None, compile=True
)

更新:我尝试了以下方法,但不起作用(tf是使用兼容版本import tensorflow.compat.v1. as tf导入的(:

with tf.Session() as sess:
saver = tf.train.import_meta_graph('/dir_to_the_model_files/saved_model.ckpt-0.meta')
saver.restore(sess, "/dir_to_the_model_files/saved_model.ckpt-0")
loaded = tf.saved_model.load(sess,tags=None,export_dir="/dir_to_the_model_files",import_scope=None)

它返回以下警告和错误:

WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'metric_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
INFO:tensorflow:Restoring parameters from /dir_to_the_model_files/saved_model.ckpt-0
<tensorflow.python.training.saver.Saver object at 0x2aaab4824a50>
WARNING:tensorflow:From <ipython-input-3-b8fd24f6b841>:9: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..

tf.saved_model.load的TensorFlow文档可能会有所帮助:

tf.estimator.estimator或1.x SavedModel API中的SavedModels具有平面图而不是tf.function对象。这些SavedModels将具有与其在.signatums中的签名相对应的函数属性,但也有一个.preeme方法,该方法允许您提取新子图的函数。这相当于导入SavedModel和命名在会话中从TensorFlow馈送和获取1.x.

您可能不得不使用不推荐使用的v1 api调用https://www.tensorflow.org/api_docs/python/tf/compat/v1/saved_model/load

最新更新