我在树莓派上使用官方 TensorFlow Lite 示例中的 detect.py 脚本。我想打印检测到的对象的类名。例如,如果对象识别模型检测到外科口罩,它将在命令行中打印 "surgical_mask"。我可以打印 detection_result,它在命令行中显示如下结果:
detections {
bounding_box {
origin_x: 135
origin_y: 14
width: 478
height: 457
}
classes {
index: 0
score: 0.875
class_name: "surgical_mask"
}
}
我不知道如何只打印 "class_name" 的值。下面是 detect.py 的代码:
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main script to run the object detection routine."""
import argparse
import sys
import time
import cv2
from tflite_support.task import core
from tflite_support.task import processor
from tflite_support.task import vision
import utils
def run(model: str, camera_id: int, width: int, height: int, num_threads: int,
        enable_edgetpu: bool) -> None:
  """Continuously run inference on images acquired from the camera.

  Each frame is read from the camera, flipped horizontally, converted to RGB
  and passed to the detector. The detection result is handed to
  `utils.visualize`, an FPS label is drawn on the frame, the raw detection
  result is printed to stdout and the frame is shown in an OpenCV window.
  The loop stops when ESC is pressed or the capture device is no longer open.

  The camera and any OpenCV windows are released on every exit path,
  including a failed frame read and errors raised while loading or running
  the model.

  Args:
    model: Name of the TFLite object detection model.
    camera_id: The camera id to be passed to OpenCV.
    width: The width of the frame captured from the camera.
    height: The height of the frame captured from the camera.
    num_threads: The number of CPU threads to run the model.
    enable_edgetpu: True/False whether the model is a EdgeTPU model.

  Raises:
    SystemExit: If a frame cannot be read from the camera.
  """
  # Variables to calculate FPS
  counter, fps = 0, 0
  start_time = time.time()

  # Visualization parameters
  row_size = 20  # pixels
  left_margin = 24  # pixels
  text_color = (0, 0, 255)  # red (OpenCV colors are BGR)
  font_size = 1
  font_thickness = 1
  fps_avg_frame_count = 10

  # Start capturing video input from the camera
  cap = cv2.VideoCapture(camera_id)
  try:
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

    # Initialize the object detection model
    base_options = core.BaseOptions(
        file_name=model, use_coral=enable_edgetpu, num_threads=num_threads)
    detection_options = processor.DetectionOptions(
        max_results=3, score_threshold=0.3)
    options = vision.ObjectDetectorOptions(
        base_options=base_options, detection_options=detection_options)
    detector = vision.ObjectDetector.create_from_options(options)

    # Continuously capture images from the camera and run inference
    while cap.isOpened():
      success, image = cap.read()
      if not success:
        # SystemExit propagates through the `finally` below, so the camera
        # is still released before the process terminates.
        sys.exit(
            'ERROR: Unable to read from webcam. Please verify your webcam settings.'
        )

      counter += 1
      image = cv2.flip(image, 1)

      # Convert the image from BGR to RGB as required by the TFLite model.
      rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

      # Create a TensorImage object from the RGB image.
      input_tensor = vision.TensorImage.create_from_array(rgb_image)

      # Run object detection estimation using the model.
      detection_result = detector.detect(input_tensor)

      # Draw the detection results on the frame (delegated to utils.visualize).
      image = utils.visualize(image, detection_result)

      # Recompute the FPS once every `fps_avg_frame_count` frames.
      if counter % fps_avg_frame_count == 0:
        end_time = time.time()
        fps = fps_avg_frame_count / (end_time - start_time)
        start_time = time.time()

      # Show the FPS
      fps_text = 'FPS = {:.1f}'.format(fps)
      text_location = (left_margin, row_size)
      cv2.putText(image, fps_text, text_location, cv2.FONT_HERSHEY_PLAIN,
                  font_size, text_color, font_thickness)

      print(detection_result)

      # Stop the program if the ESC key is pressed.
      if cv2.waitKey(1) == 27:
        break
      cv2.imshow('object_detector', image)
  finally:
    # Always free the capture device and close the preview window.
    cap.release()
    cv2.destroyAllWindows()
def main():
  """Parse the command-line options and start the detection loop via `run`."""
  cli = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  cli.add_argument(
      '--model',
      default='efficientdet_lite0.tflite',
      required=False,
      help='Path of the object detection model.')

  # The integer options only differ in flag, help text and default value.
  integer_options = (
      ('--cameraId', 'Id of camera.', 0),
      ('--frameWidth', 'Width of frame to capture from camera.', 640),
      ('--frameHeight', 'Height of frame to capture from camera.', 480),
      ('--numThreads', 'Number of CPU threads to run the model.', 4),
  )
  for flag, description, fallback in integer_options:
    cli.add_argument(
        flag, help=description, required=False, type=int, default=fallback)

  cli.add_argument(
      '--enableEdgeTPU',
      action='store_true',
      default=False,
      required=False,
      help='Whether to run the model on EdgeTPU.')

  opts = cli.parse_args()

  camera_id = int(opts.cameraId)
  thread_count = int(opts.numThreads)
  use_edgetpu = bool(opts.enableEdgeTPU)
  run(opts.model, camera_id, opts.frameWidth, opts.frameHeight, thread_count,
      use_edgetpu)


# Only start the detection loop when executed as a script.
if __name__ == '__main__':
  main()
您可以通过以下方式获得检测结果的数量:
Num_of_detections = len(detection_result.detections)
然后使用 for 循环遍历结果,并打印每个检测结果的类名:
print(detection_result.detections[x].classes[0].class_name)
我是 Stack Overflow 的新用户,不过我有一段脚本可以完成这个任务。
# Print the class name of every object detected in the current frame.
#
# The printed result has the structure detections -> classes -> class_name,
# so the name is read as an attribute. Parsing the text dump instead would
# depend on its exact line layout and line breaks.
detections = detection_result.detections
if len(detections) > 0:
    print("Detection Results: ")
    for detection in detections:
        # Use the first listed class of each detection
        # (presumably the best-scoring one; confirm against the library docs).
        if len(detection.classes) > 0:
            print(detection.classes[0].class_name)