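I have defined the following function to apply non-maximum suppression (NMS) post-processing to my predictions.
At the moment it is defined for a single prediction or output: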
import torchvision
from torchvision import transforms as torchtrans

def apply_nms(orig_prediction, iou_thresh=0.3):
    # torchvision returns the indices of the bboxes to keep
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)
    # shallow copy so the caller's prediction dict is not modified in place
    final_prediction = dict(orig_prediction)
    final_prediction['boxes'] = final_prediction['boxes'][keep]
    final_prediction['scores'] = final_prediction['scores'][keep]
    final_prediction['labels'] = final_prediction['labels'][keep]
    return final_prediction
I then apply it to a single image:
import torch

cpu_device = torch.device("cpu")
# pick one image from the test set
img, target = valid_dataset[3]
# put the model in evaluation mode
model.to(cpu_device)
model.eval()
with torch.no_grad():
    output = model([img])[0]
nms_prediction = apply_nms(output, iou_thresh=0.1)
However, I am not sure how to do this efficiently for a whole batch of images coming from a data loader:
cpu_device = torch.device("cpu")
model.eval()
with torch.no_grad():
    for images, targets in valid_data_loader:
        images = list(img.to(device) for img in images)
        outputs = model(images)
        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        # DO NMS POST PROCESSING HERE??
What is the best way to do this? How can I apply the function defined above to multiple images? Is it best done in another for loop?
Looking at the Generic Transforms section of the torchvision documentation, you can use torchvision.transforms.Lambda or work with functional transforms.
Here is an example with Lambda:
nms_transform = torchvision.transforms.Lambda(apply_nms)
You can then apply the transform through the dataset's transform argument (you could also create a custom dataset class):
dset = MyDset(..., transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), nms_transform]))