TensorFlow 数据集方法在输入函数中调用序列



tf.data.Dataset中有很多方法,如batch((,shard((,shuffle((,prefetch((,map((...等。 通常,当我们实现一个input_fn时,我们会根据我们的意愿调用它们。

我想知道当我们以不同的顺序调用这些方法时,是否会对程序产生任何影响?例如,它们在以下两个调用序列中是否相同?

dataset = dataset.shuffle().batch()
dataset = dataset.batch().shuffle()

我想知道当我们调用这些方法时是否会对程序产生任何影响 在不同的顺序?

是的,有区别。几乎总是,shuffle()应该在batch()之前调用,因为我们想洗牌记录而不是批处理。

tf.data.Dataset的转换按调用它们的相同顺序应用。

Batch 将其输入的连续元素合并为输出中的单个批处理元素。

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.arange(19))
for batch in dataset.batch(5):
print(batch)

输出:

tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int64)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int64)
tf.Tensor([15 16 17 18], shape=(4,), dtype=int64)

当我们在将数据馈送到网络之前对数据进行洗牌时。这会用buffer_size元素填充缓冲区,然后从此缓冲区中随机采样元素,用新元素替换所选元素。对于完美的随机播放,缓冲区大小应等于数据集的完整大小。

for batch in dataset.shuffle(5).batch(5):
print(batch)

输出:

tf.Tensor([2 0 1 4 8], shape=(5,), dtype=int64)
tf.Tensor([ 9  3  7  6 11], shape=(5,), dtype=int64)
tf.Tensor([12 14 15  5 13], shape=(5,), dtype=int64)
tf.Tensor([17 18 16 10], shape=(4,), dtype=int64)

你可以看到结果不是均匀的,但足够好。

但是,如果以不同的顺序应用这些方法,则会得到意外的结果。它随机排列批处理,而不是记录。

for batch in dataset.batch(5).shuffle(5):
print(batch)

输出:

tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int64)
tf.Tensor([15 16 17 18], shape=(4,), dtype=int64)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int64)

最新更新