我的任务其实很简单,但我不知道如何实现它。我打算在我的 ML 算法中使用它,但让我们简化示例。假设有一个如下所示的生成器:
nums = ((i+1) for i in range(4))
以上,将产生我们1
,2
,3
和4
。
假设上面的生成器返回单个"样本"。我想编写一个生成器方法来批处理它们。假设批大小为2
。因此,如果调用此新方法:
def batch_generator(batch_size):
do something on nums
yield batches of size batch_size
然后这个批处理生成器的输出将是:1
和2
然后3
和4
.元组/列表无关紧要。重要的是如何返回这些批次。我找到了这个yield from
Python 3.3 中引入的关键字,但它似乎对我的情况没有用。
显然,如果我们有5
个数字而不是4
,并且batch_size
是2
,我们将省略第一个生成器的最后一个产生值。
我自己的解决方案可能是,
nums = (i+1 for i in range(4))
def giveBatch(gen, numOfItems):
try:
return [next(gen) for i in range(numOfItems)]
except StopIteration:
pass
giveBatch(nums, 2)
# [1, 2]
giveBatch(nums, 2)
# [3, 4]
另一种解决方案是使用@Bharel提到的grouper
。我比较了运行这两种解决方案所需的时间。没有太大区别。我想它可以被忽略。
from timeit import timeit
def wrapper(func, *args, **kwargs):
def wrapped():
return func(*args, **kwargs)
return wrapped
nums = (i+1 for i in range(1000000))
wrappedGiveBatch = wrapper(giveBatch, nums, 2)
timeit(wrappedGiveBatch, number=1000000)
# ~ 0.998439
wrappedGrouper = wrapper(grouper, nums, 2)
timeit(wrappedGrouper, number=1000000)
# ~ 0.734342
在迭代工具下,你有一个代码片段,它就是这样做的:
from itertools import zip_longest
def grouper(iterable, n, fillvalue=None):
"Collect data into fixed-length chunks or blocks"
# grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
args = [iter(iterable)] * n
return zip_longest(*args, fillvalue=fillvalue)
您不必每次都调用方法,而是拥有一个迭代器,它可以返回批处理,更高效、更快,并且可以处理极端情况,例如过早耗尽数据而不会丢失数据。
这正是我所需要的:
def giveBatch(numOfItems):
nums = (i+1 for i in range(7))
while True:
yield [next(nums) for i in range(numOfItems)]