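I trained a model on a custom image dataset with the TensorFlow Object Detection API and then ran the object detection tutorial with the trained model. I get an error when loading the label map. I have checked the label map file, and it seems consistent with the dictionary contents, so I don't understand why the error occurs.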
Code:
# Name of the exported model directory.
MODEL_NAME = 'new_graph.pb'
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'
# List of the strings that are used to add the correct label to each box.
PATH_TO_LABELS = 'training/labelmap.pbtxt'
NUM_CLASSES=3
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
category_index = label_map_util.convert_label_map_to_categories(PATH_TO_LABELS , max_num_classes=NUM_CLASSES, use_display_name=True)
Error:
AttributeError Traceback (most recent call last)
<ipython-input-27-7acf82e14013> in <module>
1 #category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
2
----> 3 category_index = label_map_util.convert_label_map_to_categories(PATH_TO_LABELS , max_num_classes=NUM_CLASSES)
4
D:\me1eye\New folder\29082020\models\research\object_detection\utils\label_map_util.py in convert_label_map_to_categories(label_map, max_num_classes, use_display_name)
118 })
119 return categories
--> 120 for item in label_map.item:
121 if not 0 < item.id <= max_num_classes:
122 logging.info(
AttributeError: 'str' object has no attribute 'item'
labelmap.pbtxt file:
item {
id: 1
name: 'Cat'
}
item {
id: 2
name: 'Grabes'
}
item {
id: 3
name: 'Olive'
}
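To confirm that the label map really is consistent, it helps to check that its ids run from 1 to NUM_CLASSES, because convert_label_map_to_categories only keeps items with 0 < id <= max_num_classes (line 121 in the traceback). The sketch below is a rough, hypothetical check: read_label_map_ids and check_ids are illustrative names, and the regex is not a real protobuf text parser; it assumes each item block lists id before name, as in the file above.

```python
import re

# The label map from the question, inlined so the sketch runs on its own.
LABEL_MAP_TEXT = """
item {
  id: 1
  name: 'Cat'
}
item {
  id: 2
  name: 'Grabes'
}
item {
  id: 3
  name: 'Olive'
}
"""

# Hypothetical helper: a rough regex scan, NOT a full protobuf text-format parser.
# Assumes every item block has `id` followed by `name`, as in the file above.
ITEM_RE = re.compile(r"item\s*\{\s*id:\s*(\d+)\s*name:\s*['\"]([^'\"]*)['\"]\s*\}")


def read_label_map_ids(text):
    """Return {id: name} for each item block found in the text."""
    return {int(item_id): name for item_id, name in ITEM_RE.findall(text)}


def check_ids(id_to_name, num_classes):
    """Return a list of problems; empty when the ids are exactly 1..num_classes."""
    problems = []
    expected = set(range(1, num_classes + 1))
    found = set(id_to_name)
    for missing in sorted(expected - found):
        problems.append(f"missing id {missing}")
    for extra in sorted(found - expected):
        problems.append(f"id {extra} outside 1..{num_classes}")
    return problems


labels = read_label_map_ids(LABEL_MAP_TEXT)
print(labels)               # {1: 'Cat', 2: 'Grabes', 3: 'Olive'}
print(check_ids(labels, 3))  # []
```

For this file the check comes back empty, which points away from the label map contents and toward how it is passed to label_map_util.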
The following imports need to be changed:
from utils import label_map_util
----> from object_detection.utils import label_map_util
from utils import visualization_utils as vis_util
----> from object_detection.utils import visualization_utils as vis_util
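Before calling convert_label_map_to_categories, you first need to load the label map data with load_labelmap. Your code passes the file name (a plain string) instead of the parsed label map, so the loop "for item in label_map.item" fails with AttributeError: 'str' object has no attribute 'item'.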
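The failure can be reproduced without TensorFlow. The sketch below uses hypothetical stand-ins (Item, LabelMap, to_categories) for the parsed label map proto that load_labelmap returns and for a simplified version of the loop shown in the traceback; it is not the library's code, only an illustration of why a string cannot be used in place of the parsed object.

```python
from dataclasses import dataclass, field
from typing import List


# Hypothetical stand-ins for the parsed label map proto:
# a message with a repeated `item` field, each item having `id` and `name`.
@dataclass
class Item:
    id: int
    name: str


@dataclass
class LabelMap:
    item: List[Item] = field(default_factory=list)


def to_categories(label_map, max_num_classes):
    """Simplified sketch of the loop in convert_label_map_to_categories."""
    categories = []
    for item in label_map.item:  # a str has no .item attribute -> AttributeError
        if not 0 < item.id <= max_num_classes:
            continue  # the real function logs and skips out-of-range ids
        categories.append({"id": item.id, "name": item.name})
    return categories


parsed = LabelMap(item=[Item(1, "Cat"), Item(2, "Grabes"), Item(3, "Olive")])
print(to_categories(parsed, 3))

try:
    to_categories("training/labelmap.pbtxt", 3)  # passing the path, as in the question
except AttributeError as err:
    print(err)  # 'str' object has no attribute 'item'
```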
Try this code:
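label_map = label_map_util.load_labelmap(PATH_TO_LABELS)  # parse the .pbtxt file into a label map proto
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)  # list of {'id', 'name'} dicts
category_index = label_map_util.create_category_index(categories)  # dict keyed by id, as the tutorial's drawing code expects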
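Note that convert_label_map_to_categories returns a list, while the tutorial's visualization code looks categories up by class id in a dict; label_map_util.create_category_index performs that conversion, and label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True) (the commented-out line in the traceback) does both steps at once. A pure-Python sketch of the expected shape, where index_by_id is a hypothetical helper that mirrors create_category_index:

```python
# Category list in the shape convert_label_map_to_categories returns
# (assumed here from the label map in the question).
categories = [
    {"id": 1, "name": "Cat"},
    {"id": 2, "name": "Grabes"},
    {"id": 3, "name": "Olive"},
]


def index_by_id(categories):
    """Key each category dict by its integer id (mirrors create_category_index)."""
    return {cat["id"]: cat for cat in categories}


category_index = index_by_id(categories)
print(category_index[2]["name"])  # Grabes
```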