在 Tensorflow 2.0 中迭代无限重复的 tf.data 数据集的正确方法是什么?



TF2.0 文档建议使用 python for 循环迭代数据集:

for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
# do training

问题是,如果数据集无限期地重复(据我所知,出于性能原因,这是有意义的(,这个循环将永远不会结束。

我目前正在做的是设置一些我想要迭代的时期和训练步骤:

train_iter = iter(train_dataset)
for i in range(num_epochs):
# do some setup
for step in range(num_batches):
(x_batch, y_batch) = next(train_iter)
# do training
# log metrics

我不确定的是,这是否会对我的训练过程的表现产生负面影响。这会让我的训练运行得更慢,还是通过这样运行我的训练来阻止 Tensorflow 优化我的代码? 最重要的是,设置在纪元期间要处理的批次数可能有点烦人,因为我想在我的数据管道中进行随机扩充。因此,我的数据集中唯一样本的数量在不同的训练会话之间可能会有所不同。不过这不是什么大问题。

我试图通过谷歌找到这个问题的答案,但不幸的是没有运气。

代码的问题,

train_iter = iter(train_dataset)
for i in range(num_epochs):
# do some setup
for step in range(num_batches):
(x_batch, y_batch) = next(train_iter)

就是每epochmodel看到batches的顺序相同,效率不高。

此类代码的示例输出如下所示:

tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)

如上所示,对应于每个Epoch的值是相同的,或者换句话说,每个epochBatches重复(4, 0, 8, 6, 73,1,2,9,5重复三次(。

以不同顺序传递batches的优化和高效方法是使用参数reshuffle_each_iteration=True。示例代码如下所示:

import tensorflow as tf
dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(buffer_size=5, reshuffle_each_iteration=True)
iter(dataset)
buffer_size = 10
batch_size = 2
for epoch in range(num_epochs):
dataset_epoch = dataset.batch(batch_size)
for x, y in dataset_epoch:
print(x,y)

上述代码的输出如下所示,可以观察到与任何批处理对应的值都没有重复:

tf.Tensor(2, shape=(), dtype=int64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(6, shape=(), dtype=int64)
tf.Tensor(9, shape=(), dtype=int64) tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(5, shape=(), dtype=int64) tf.Tensor(8, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(7, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64) tf.Tensor(7, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(8, shape=(), dtype=int64)
tf.Tensor(9, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(4, shape=(), dtype=int64)

希望这有帮助。快乐学习!

相关内容

最新更新