I am trying to run a classifier model on top of an OD model (used to localize objects). To reduce latency I am using multiprocessing for both the OD and the classifier model. The output is correct, but I am getting duplicate results.
I have an 8-core machine, so I am creating the pool with pool=mp.Pool(8). I am using map_async with the list of image paths as the iterable, and I call .get() to retrieve the results as a list.
Initially I had not called pool.join() after pool.close(); I fixed that after going through a few sites. The wrong output I am getting is caused by the chunksize I pass to pool.map_async(): the same output is duplicated as many times as the chunksize. But from my understanding of chunksize, it should only create batches of that size and hand each batch over to one process.
import multiprocessing as mp

return_stuff_classifier = []

def label_it(image_path):
    file_name = image_path
    image_name = image_path.split('/')[-1]
    frame_id = image_name.split('_')[0]
    object_id = image_name.split('_')[1].split('.')[0]
    label = "gt"
    result = "0.86"  # Here I have explicitly mentioned this, not to go through the
                     # classifier model prediction
    return_stuff_classifier.append((frame_id, object_id, label, result))
    return return_stuff_classifier

def multiprocessor():
    m_class = mp.Manager()
    queue_class = m_class.Queue()
    pool_class = mp.Pool(8)
    # Here cropped_images_no == 24 -> chunk_size_class == 3
    chunk_size_class = round(cropped_images_no / 8)
    results_class = pool_class.map_async(label_it, cropped_images, chunk_size_class).get()
    # label_it is the method to be multiprocessed
    # cropped_images is the list of all image paths to be multiprocessed
    pool_class.close()
    pool_class.join()
    final_results.append(results_class)
Output:
[[['443', '10', 'ugt', '0.85964435'],
['443', '11', 'ugut', '0.48011008'],
['443', '4', 'gut', '0.50242084']],
[['443', '10', 'ugt', '0.85964435'],
['443', '11', 'ugut', '0.48011008'],
['443', '4', 'gut', '0.50242084']],
[['443', '10', 'ugt', '0.85964435'],
['443', '11', 'ugut', '0.48011008'],
['443', '4', 'gut', '0.50242084']],
[['443', '2', 'ugut', '0.8623834'],
['443', '6', 'gt', '0.95684755'],
['443', '1', 'gut', '0.683893']],
[['443', '2', 'ugut', '0.8623834'],
['443', '6', 'gt', '0.95684755'],
['443', '1', 'gut', '0.683893']],
[['443', '2', 'ugut', '0.8623834'],
['443', '6', 'gt', '0.95684755'],
['443', '1', 'gut', '0.683893']]]
Expected output:
[[['443', '10', 'ugt', '0.85964435'],
['443', '11', 'ugut', '0.48011008'],
['443', '4', 'gut', '0.50242084']],
[['443', '2', 'ugut', '0.8623834'],
['443', '6', 'gt', '0.95684755'],
['443', '1', 'gut', '0.683893']]]
I think the problem is that your label_it() function appends its result to the return_stuff_classifier list every time it is executed and then returns the whole list, so each call returns a value that accumulates the results of previous calls made in the same worker process. How many times this happens is governed by the chunksize, because each worker runs chunksize calls back to back against its own copy of that list.
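To see the mechanism in isolation, here is a minimal sketch that reproduces the accumulation in a single process (the file names are placeholders, this is not taken from your code):

results = []  # module-level list, shared by every call in the same process

def label_once(path):
    results.append(path)  # each call grows the shared list...
    return results        # ...and returns a reference to that same list

for p in ['a.jpg', 'b.jpg', 'c.jpg']:
    print(label_once(p))
# ['a.jpg']
# ['a.jpg', 'b.jpg']
# ['a.jpg', 'b.jpg', 'c.jpg']

In the pool, a worker runs chunksize such calls back to back and only then sends that chunk's return values back to the parent, so every one of the chunksize returned references already points at the fully accumulated list, which is why each block of results is repeated exactly chunksize times.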
Fortunately this is easy to fix: just return the tuple you were appending to the list. Once you do that, the list is no longer needed at all.
Note: I had to add an if __name__ == '__main__': guard to the code so it would work on machines running Windows, where child processes are created differently than on Unix-like operating systems. It should still work on those as well, so doing it this way is portable. The need for this is described in the subsection titled "Safe importing of main module" in the programming guidelines of the multiprocessing module's documentation.
The other change is moving the get() call to after pool_class.join(), because by that point all of the child processes have finished. It isn't strictly necessary here, since the main process has nothing further to do anyway, but it is the canonical way to retrieve results from map_async(), presumably because it allows the main process to carry out other tasks in the meantime, if it has any.
import multiprocessing as mp
from pprint import pprint

cropped_images = [f'./image_directory_path/frame_{i}.jpg' for i in range(1, 25)]

#return_stuff_classifier = []  # No longer needed.

def label_it(image_path):
    file_name = image_path
    image_name = image_path.split('/')[-1]
    frame_id = image_name.split('_')[0]
    object_id = image_name.split('_')[1].split('.')[0]
    label = "gt"
    result = "0.86"  # Here I have explicitly mentioned this, not to go through the
                     # classifier model prediction
#    return_stuff_classifier.append((frame_id, object_id, label, result))
#    return return_stuff_classifier
    return (frame_id, object_id, label, result)  # Just return the results.

if __name__ == '__main__':

    def multiprocessor():
        m_class = mp.Manager()
        queue_class = m_class.Queue()
        pool_class = mp.Pool(8)
        final_results = []
        # Here len(cropped_images) == 24 -> chunk_size_class == 3
        chunk_size_class = round(len(cropped_images) / 8)
        print(f'{chunk_size_class=}')
        results_class = pool_class.map_async(label_it, cropped_images, chunk_size_class)
        # label_it is the method to be multiprocessed
        # cropped_images is the list of all image paths to be multiprocessed
        pool_class.close()
        pool_class.join()
        final_results.append(results_class.get())
        pprint(final_results)

    multiprocessor()
Here is the printed output showing that there are no longer any duplicates:
chunk_size_class=3
[[('frame', '1', 'gt', '0.86'),
('frame', '2', 'gt', '0.86'),
('frame', '3', 'gt', '0.86'),
('frame', '4', 'gt', '0.86'),
('frame', '5', 'gt', '0.86'),
('frame', '6', 'gt', '0.86'),
('frame', '7', 'gt', '0.86'),
('frame', '8', 'gt', '0.86'),
('frame', '9', 'gt', '0.86'),
('frame', '10', 'gt', '0.86'),
('frame', '11', 'gt', '0.86'),
('frame', '12', 'gt', '0.86'),
('frame', '13', 'gt', '0.86'),
('frame', '14', 'gt', '0.86'),
('frame', '15', 'gt', '0.86'),
('frame', '16', 'gt', '0.86'),
('frame', '17', 'gt', '0.86'),
('frame', '18', 'gt', '0.86'),
('frame', '19', 'gt', '0.86'),
('frame', '20', 'gt', '0.86'),
('frame', '21', 'gt', '0.86'),
('frame', '22', 'gt', '0.86'),
('frame', '23', 'gt', '0.86'),
('frame', '24', 'gt', '0.86')]]
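As an aside, since pool.map() blocks until every result is ready, the same fixed label_it can also be driven from a with block, which cleans the pool up for you on exit. This is only a minimal sketch under the same assumptions as above (placeholder image paths, the hard-coded "gt"/"0.86" values), not part of the original answer:

import multiprocessing as mp
from pprint import pprint

def label_it(image_path):
    # Same fixed worker as above: derive the ids from the file name and
    # return a single tuple instead of a shared, module-level list.
    image_name = image_path.split('/')[-1]
    frame_id = image_name.split('_')[0]
    object_id = image_name.split('_')[1].split('.')[0]
    return (frame_id, object_id, "gt", "0.86")

if __name__ == '__main__':
    cropped_images = [f'./image_directory_path/frame_{i}.jpg' for i in range(1, 25)]
    # Guard against a chunksize of 0 when the list is very short.
    chunk_size = max(1, round(len(cropped_images) / 8))
    with mp.Pool(8) as pool:
        # map() blocks until all results are in, so it is safe that the
        # with block terminates the pool immediately afterwards.
        results = pool.map(label_it, cropped_images, chunk_size)
    pprint(results)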