How to efficiently apply NMS (non-maximum suppression) to multiple images from a DataLoader (PyTorch)



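I have defined the following function to run non-maximum suppression (NMS) as a post-processing step on my predictions.

Currently, it is written for a single prediction (one output dict):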

```python
import torchvision

def apply_nms(orig_prediction, iou_thresh=0.3):
    # torchvision returns the indices of the bboxes to keep
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)

    # copy the dict so the caller's prediction is not modified in place
    final_prediction = dict(orig_prediction)
    final_prediction['boxes'] = final_prediction['boxes'][keep]
    final_prediction['scores'] = final_prediction['scores'][keep]
    final_prediction['labels'] = final_prediction['labels'][keep]

    return final_prediction
```
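To make clear what `torchvision.ops.nms` computes for one image, here is a plain-Python sketch of the standard greedy algorithm: visit boxes in decreasing score order and keep a box only if its IoU with every already-kept box is at most `iou_thresh`. It assumes boxes in `(x1, y1, x2, y2)` format, which is what torchvision expects. `iou` and `greedy_nms` are illustrative names, not torchvision functions, and this is not torchvision's actual implementation.

```python
def iou(a, b):
    """Intersection-over-union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_thresh):
    """Return indices of kept boxes, ordered by decreasing score."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # a box survives only if no higher-scored kept box overlaps it by more than iou_thresh
        if all(iou(boxes[i], boxes[j]) <= iou_thresh for j in keep):
            keep.append(i)
    return keep
```

For example, `greedy_nms([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]], [0.9, 0.8, 0.7], 0.3)` returns `[0, 2]`: the first two boxes overlap with IoU 81/119 (about 0.68), so the lower-scored one is suppressed, while the third box overlaps nothing.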

I then apply it to a single image:

```python
import torch

cpu_device = torch.device("cpu")
# pick one image from the test set
img, target = valid_dataset[3]
# put the model in evaluation mode
model.to(cpu_device)
model.eval()
with torch.no_grad():
    output = model([img])[0]

nms_prediction = apply_nms(output, iou_thresh=0.1)
```

However, I am not sure how to efficiently handle a whole batch of images coming from the data loader:

```python
cpu_device = torch.device("cpu")
model.eval()
with torch.no_grad():
    for images, targets in valid_data_loader:
        images = list(img.to(device) for img in images)

        outputs = model(images)
        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        # DO NMS POST PROCESSING HERE??
```

What is the best approach? How do I apply the function defined above to multiple images? Is this best done in another for loop?

Looking at the Generic Transforms section of the torchvision documentation, you can use `torchvision.transforms.Lambda` or functional transforms.

Here is an example with `Lambda`:

```python
nms_transform = torchvision.transforms.Lambda(apply_nms)
```

You can then apply the transform through your dataset's `transform` argument (you can also write a custom Dataset class):

```python
# pass the Lambda instance itself (not nms_transform()), and close both brackets
dset = MyDset(..., transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), nms_transform]))
```
