在python中,有没有更优雅的方法可以通过批处理遍历数据



我想在python中实现我自己的数据加载器。目标是按小批量随机遍历数据集,我想知道是否有更优雅的方法来实现它

例如,我有一个数据集dataset=[1,2,3,4,5,6,7,8,9]。当为batch_size=3时,返回的三个批次应为:[1,5,7],[2,3,4],[6,8,9]。这是我的成就:

import numpy as np
class DataLoader:
def __init__(self, data: list, batch_size: int):
self.data = data
self.batch_size = batch_size
self.samples_reserve = None
def _reset(self):
self.samples_reserve = np.arange(len(self.data)).tolist()
def __iter__(self):
self._reset()
return self
def __next__(self):
if len(self.samples_reserve) == 0:
raise StopIteration
samples_choice = set(np.random.choice(self.samples_reserve, self.batch_size, replace=False))
self.samples_reserve = list(set(self.samples_reserve) - samples_choice)
return list(samples_choice)
def __len__(self):
return int(len(self.data) / self.batch_size)
if __name__ == '__main__':
for i in DataLoader([1,2,3,4,5,6,7,8,9], 3):
print(i)

__next__函数必须通过在setlist之间传输数据来维护当前保留的数据。我想知道我是否可以用一种更优雅的方式实现以下代码,例如,有没有一些api函数可以直接使用,例如sample?

samples_choice = set(np.random.choice(self.samples_reserve, self.batch_size, replace=False))
self.samples_reserve = list(set(self.samples_reserve) - samples_choice)

我的解决方案,基于numpy.random.Generator.choice:

def get_samples(dataset, size):
if not dataset:
return []
rng = np.random.default_rng()
if size >= len(dataset):
size = len(dataset)
sd = set(dataset)
samples = []
while sd:
sample = rng.choice(list(sd), min(size, len(sd)), replace=False)
samples.append(list(sample))
sd -= set(sample)
return samples
print(get_samples(list(range(20)), 7))

相关内容

最新更新