如何在数据加载器中使用Batchsampler



我需要在pytorchDataLoader中使用BatchSampler,而不是多次调用数据集的__getitem__(远程数据集,每个查询都很昂贵(
我无法理解如何将批次采样器用于任何给定的数据集。

例如

class MyDataset(Dataset):
def __init__(self, remote_ddf, ):
self.ddf = remote_ddf
def __len__(self):
return len(self.ddf)
def __getitem__(self, idx):
return self.ddf[idx] --------> This is as expensive as a batch call
def get_batch(self, batch_idx):
return self.ddf[batch_idx]
my_loader = DataLoader(MyDataset(remote_ddf), 
batch_sampler=BatchSampler(Sampler(), batch_size=3))

我不明白的是,我如何使用get_batch函数而不是__getitem__函数,这在网上或torch文档中都没有找到任何例子。
编辑:根据Szymon Maszke的回答,这就是我尝试的,然而,__get_item__每次调用都会得到一个索引,而不是batch_size大小的列表

class Dataset(Dataset):
def __init__(self):
...
def __len__(self):
...
def __getitem__(self, batch_idx):  ------> here I get only one index
return self.wiki_df.loc[batch_idx]

loader = DataLoader(
dataset=dataset,
batch_sampler=BatchSampler(
SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
num_workers=self.hparams.num_data_workers,
)

你不能用get_batch代替__getitem__,我认为这样做没有意义。

torch.utils.data.BatchSamplerSampler()实例(在本例中为3(中获取索引,并将其返回为list,以便在MyDataset__getitem__方法中使用这些索引(请检查源代码,如果需要,大多数采样器和数据相关实用程序都很容易使用(。

我假设您的self.ddf支持列表切片(例如,self.ddf[[25, 44, 115]]正确返回值,并且只使用一个昂贵的调用(。在这种情况下,只需将get_batch切换为__getitem__即可。

class MyDataset(Dataset):
def __init__(self, remote_ddf, ):
self.ddf = remote_ddf
def __len__(self):
return len(self.ddf)
def __getitem__(self, batch_idx):
return self.ddf[batch_idx] -> batch_idx is a list

编辑:必须将batch_sampler指定为sampler,否则批次将被划分为单个索引。这应该很好:

loader = DataLoader(
dataset=dataset,
# This line below!
sampler=BatchSampler(
SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
),
num_workers=self.hparams.num_data_workers,
)

相关内容

  • 没有找到相关文章

最新更新