pytorch中的随机数据交换



我想以[0,180]之间的随机度旋转Dataset中的所有图像。如果我编写一个转换函数,并将图像传递给Dataset类的__getitem__函数中的这个函数。这是否意味着:

  1. 每个图像都是随机旋转的
  2. 每个批次中的图像都以相同的程度旋转,但这个程度在批次之间随机变化(调用(

如果你能为我澄清这一点,我将不胜感激。

在映射的数据集中,__getitem__用于从数据集中选择单个元素。

PyTorch/Torchvision中的随机变换的工作方式是,每次调用变换时,它们都会应用一个唯一的随机变换。这意味着:

  1. 数据集中的每个图像确实是随机旋转的,但数量不相同

  2. 此外,批处理中的图像会得到不同的转换。换句话说,批处理中的元素不会共享相同的转换参数。


以下是一个伪数据集的最小示例:

class D(Dataset):
def __init__(self, n):
super().__init__()
self.n = n
self.transforms = T.Lambda(lambda x: x*randint(0,10))

def __len__(self):
return self.n
def __getitem__(self, index):
x = self.transforms(index)
return x

在这里,您可以看到interintra批次的随机转换器:

>>> dl = DataLoader(D(10), batch_size=2)
>>> for i, x  in enumerate(dl):
...     print(f'batch {i}: elements {2*i} and {2*i+1} = {x.tolist()}')
batch 0: elements 0 and 1 = [0, 2]
batch 1: elements 2 and 3 = [14, 27]
batch 2: elements 4 and 5 = [32, 40]
batch 3: elements 6 and 7 = [60, 0]
batch 4: elements 8 and 9 = [80, 27]

最新更新