下面是一个使用tensorflow的1.15.0对象检测API的示例。本教程在以下几个方面做得很清楚:
- 如何下载模型
- 如何用.xml文件加载自定义数据库,从中生成.cvs文件,然后生成.record文件
- 如何配置培训管道
- 如何获得tensorboard图
- 如何训练网络节省检查点(使用modelmain.py(
- 如何导出(保存(模型(使用export_inference_graph.py(
然而,我无法完成的是加载保存的模型以使用它。我试过tf.saved_model.loader.load(sess, flags, export_dir
,但我得到了
INFO:tensorflow:Saver not created because there are no variables in the graph to restore.
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
export_dir
中给出的文件夹具有以下结构:
+dir
+saved_model
-saved_model.pb
-model.ckpt.data-00000-of-00001
-model.ckpt.index
-checkpoint
-frozen_inference_graph.pb
-model.ckpt.meta
-pipeline.config
我在这里的最终目标是用相机捕捉图像,并将其输入网络进行实时物体检测\作为中间步骤,现在我只想能够提供一张图片并获得输出。我本来可以训练球网,但现在我不能用了。
提前谢谢。
我找到了一个关于如何下载模型的示例,让我可以浏览它\由于示例中下载的文件的文件夹格式与我在代码中获得的格式相同,我只需要对其进行调整
下载模型的orifinal函数是
def load_model(model_name):
base_url = 'http://download.tensorflow.org/models/object_detection/'
model_file = model_name + '.tar.gz'
model_dir = tf.keras.utils.get_file(
fname=model_name,
origin=base_url + model_file,
untar=True)
model_dir = pathlib.Path(model_dir)/"saved_model"
model = tf.saved_model.load(str(model_dir))
model = model.signatures['serving_default']
return model
然后我用这个功能创建了这个新的
def load_local_model(model_path):
model_dir = pathlib.Path(model_path)/"saved_model"
model = tf.saved_model.load(str(model_dir))
model = model.signatures['serving_default']
return model
起初这不起作用,因为tf.saved_model.load
需要3个参数,但通过在同一个示例中导入两个import块解决了这个问题,我不知道import是怎么做到的,也不知道为什么(我会在得到答案后编辑它(,但目前这段代码起作用,这个示例让我们做更多的事情。
导入块是以下
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from IPython.display import display
和
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
EDIT真正需要的是以下块。
import os
import pathlib
if "models" in pathlib.Path.cwd().parts:
while "models" in pathlib.Path.cwd().parts:
os.chdir('..')
elif not pathlib.Path('models').exists():
!git clone --depth 1 https://github.com/tensorflow/models
%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
%%bash
cd models/research
pip install .
否则此导入块将不起作用
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util