我正在使用一个训练神经网络的代码。代码使用PyTorch的DataLoader为每次迭代加载数据。代码如下
for step, data in enumerate(dataloader, 0):
............................................................
output = neuralnetwork_model(data)
.............................................................
这里的步骤是一个整数,给出值0、1、2、3、.......数据给出了每一步的一批样本。代码在每一步将相应的批次传递给神经网络。
我需要在第n步访问第n+1步的数据,我需要这样的东西
for step, data in enumerate(dataloader, 0):
............................................................
output = neuralnetwork_model(data)
access = data_of_next_step
.............................................................
我怎样才能做到这一点?
在迭代级别执行这样的操作似乎比必须更改数据加载器的实现更方便。查看遍历具有重叠的n
连续元素,您可以使用itertools.tee
:
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = tee(iterable)
next(b, None)
return zip(a, b)
因此,您只需在包装的数据加载器上迭代:
>>> for batch1, batch2 pairwise(dataloader)
... # batch1 is current batch
... # batch2 is batch of following step