如何获取放入 PyTorch DataLoader 中的图像的文件名?



我使用 PyTorch 这样加载图像:

inf_data = InfDataloader(img_folder=args.imgs_folder, target_size=args.img_size)
# Inference does not benefit from shuffling. Keeping shuffle off preserves the
# dataset's numerically sorted file order, so predictions come out in the same
# order as the input files.
inf_dataloader = DataLoader(inf_data, batch_size=1, shuffle=False, num_workers=2)

然后:

# Run inference without building the autograd graph.
with torch.no_grad():
    for batch_idx, (img_np, img_tor) in enumerate(inf_dataloader, start=1):
        img_tor = img_tor.to(device)
        pred_masks, _ = model(img_tor)

但我还想获取每张图像对应的文件名。有人能帮帮我吗?非常感谢!

DataLoader 本身无法提供文件名,但 Dataset(即问题中的 InfDataloader)保存了每个样本对应的文件路径,因此可以在 `__getitem__` 中把文件名和图像张量一起返回。

class InfDataloader(Dataset):
    """
    Dataset for inference over every entry in one image folder.

    Despite its name this is a ``torch.utils.data.Dataset``; wrap it in a
    ``DataLoader`` for batching. Each item is ``(img_np, img_tor, name)``,
    where ``name`` is the path of the source image file. That lets the file
    name travel through the DataLoader together with the image.
    """

    def __init__(self, img_folder, target_size=256):
        """
        :param img_folder: Folder whose entries are all treated as images.
        :param target_size: Target size forwarded to ``pad_resize_image``.
        """
        self.imgs_folder = img_folder
        # Numerically named files are ordered by number ("2.jpg" before
        # "10.jpg"). Any other names follow in alphabetical order instead of
        # crashing the int() conversion.
        img_list = sorted(os.listdir(img_folder), key=self._numeric_sort_key)
        self.img_paths = [os.path.join(img_folder, fname) for fname in img_list]

        self.target_size = target_size
        # Standard ImageNet channel mean/std; presumably what the model was
        # trained with -- confirm against the training pipeline.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    @staticmethod
    def _numeric_sort_key(fname):
        """Sort key: numeric stems first (by value, then name), other names after."""
        stem = os.path.splitext(fname)[0]
        try:
            # The file name is the tie-breaker, so "01.jpg" / "1.jpg" sort
            # deterministically.
            return (0, int(stem), fname)
        except ValueError:
            return (1, 0, fname)

    def __len__(self):
        """Number of images found; samplers such as RandomSampler call len()."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """
        __getitem__ for inference.

        :param idx: Index into ``self.img_paths``.
        :return: ``(img_np, img_tor, name)``. img_np is the RGB image as
            returned by ``pad_resize_image`` (H x W x C, pixel values 0-255).
            img_tor is the same image as a normalized float tensor in
            C x H x W layout. name is the image's file path.
        :raises OSError: if OpenCV cannot read the file.
        """
        name = self.img_paths[idx]
        img = cv2.imread(name)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise OSError(f"cv2.imread could not read image: {name}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Pad images to target size
        img_np = pad_resize_image(img, None, self.target_size)
        img_tor = img_np.astype(np.float32) / 255.0
        # HWC -> CHW, the layout torch models expect.
        img_tor = torch.from_numpy(np.transpose(img_tor, axes=(2, 0, 1))).float()
        img_tor = self.normalize(img_tor)
        return img_np, img_tor, name

这里我添加了 `name = self.img_paths[idx]` 这一行,并把 name 作为第三个返回值一起返回。

这样,在推理循环中就可以把 name 一起解包出来:

# Run inference without building the autograd graph.
with torch.no_grad():
    # The default DataLoader collate gathers the per-sample path strings into
    # a list of length batch_size, so `name` holds that batch's file paths.
    for batch_idx, (img_np, img_tor, name) in enumerate(inf_dataloader, start=1):
        img_tor = img_tor.to(device)
        pred_masks, _ = model(img_tor)

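这样就能得到每张图像的文件名。注意:DataLoader 默认的 collate 函数会把同一批次中的字符串收集成列表,因此 name 是长度为 batch_size 的列表;batch_size=1 时可用 name[0] 取出对应的路径。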

相关内容

  • 没有找到相关文章

最新更新