如何在PyTourch中创建一个平衡循环迭代器



假设我有两个类。其中一个我只有17个样本,另外83个。我希望每个历元中每个类的数据量始终相等(在这种情况下是17乘17(。此外,我想在类中滑动采样一个窗口,在那里我每个历元都有更多的数据(前17个,下17个,…(

目前,我有一个循环采样迭代器,如下所示:

class CyclicIterator:
def __init__(self, loader, sampler):
self.loader = loader
self.sampler = sampler
self.epoch = 0
self._next_epoch()
def _next_epoch(self):
self.iterator = iter(self.loader)
self.epoch += 1
def __len__(self):
return len(self.loader)
def __iter__(self):
return self
def __next__(self):
try:
return next(self.iterator)
except StopIteration:
self._next_epoch()
return next(self.iterator)

我想知道如何强制每个类的所有样本在每个历元中具有相等的计数?

对于平衡批次,即每个批次中每个类别的样本数量相等(或接近相等(,有一些方法:

-过采样(使较小大小的类过采样,直到达到最高采样数(。在这种方法中,您可以使用以下代码:

https://github.com/galatolofederico/pytorch-balanced-batch

-欠采样(根据最小类别编号提供所有类别的样本数量(。根据我的经验,以下函数与使用PyTorch库的函数类似:

torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

其中权重是每个样本的概率,这取决于每个类别有多少样本,例如,如果数据很简单,数据=[0,1,0,1],类"0"计数为3,类"1"计数为2,则权重向量为[1/3,1/2,1/3,1/3,1/2]。有了这个,你可以调用WeightedRamdomSampler,它会为你制作的。您需要在Dataloader中调用它。设置它的代码是:

sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
train_dataloader = DataLoader(dataset_train, batch_size=mini_batch,
sampler=sampler, shuffle=False,
num_workers=1)

相关内容

最新更新