使用pytorch对数据集进行过采样



我对PyTorch和python很陌生。我有一个二进制分类问题,其中一个类的样本比另一个多,所以我决定通过对样本数量较少的类进行更多增强来对其进行过采样,所以例如,我将从一个类中的一个样本中生成7个图像,而对于另一个类,我将在一个样本中生成3个图像。我正在使用imguag与PyTorch进行扩充,所以我不确定哪种更好,先扩充我的数据集,然后将其传递给torch.utils.data.dataset类,或者读取数据并在dataset类的init函数内扩充它。

我认为还有另一种方法可以处理不平衡的数据,nn.BCELoss是二进制分类问题的常见选择,您可以设置pos_weight来平衡正样本和负样本。若您这样做,您可以对所有样本应用相同的扩增。这是代码:

# defines the augmentation
transform = transforms.Compose([transforms.RandomRotation(20),
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# initializes the data set
dataset = Dataset(train_data_path, transforms=transform)
# defines the loss function
criterion = torch.nn.BCELoss(torch.tensor([10.]))

最新更新