Pytorch数据加载器将图像连接到输入图像



在PyTorch数据加载器中,如何将图像(比如x.jpg(按带连接到每个输入图像。即,实际上我将有4波段输入(3波段输入jpg和1波段x.jpg。如何实现它。

请找到下面的例子,我目前的数据加载器只是为了加载图像。对此,我想将x.jpg添加到";图像";(即输入图像,不屏蔽(

from PIL import Image
class lakeDataSet(Dataset):
def __init__(self, root, transform):
super().__init__()
self.root = root
self.img_dir = os.path.join(root,'image-c3/c3-crop')   #9UAV
self.mask_dir = os.path.join(root,'label-c3/c3-crop')
# self.mask_dir = os.path.join(root,'test')
self.files = [fname for fname in os.listdir(self.img_dir) if fname.endswith('.jpg')]
self.transform = transform
def __len__(self):
return len(self.files)
def __getitem__(self,I):
fname = self.files[i]
img_path = os.path.join(self.img_dir, fname)
mask_path = os.path.join(self.mask_dir, fname)
img = self.transform(Image.open(img_path))
mask = self.transform(Image.open(mask_path))
return img, mask

我想self.transform已经有了ToTensor。否则,您也应该指定它。

然后你可以只凹入第一个维度。像

x_jpg = self.transform(Image.open('x.jpg'))
img = torch.cat((img, x_jpg), 0)

x.jpg必须只有1个通道,如果它是RGB,那么很明显它将变成6个通道,而不是4个。

最新更新