TensorFlow DataSet.shuffle似乎没有重复()就没有随机散布



我的代码与TensorFlow 2.0教程具有相似的模式。我希望我的数据集对象在每个时期重新填充。

dataset = tf.data.Dataset.from_tensor_slices(['a','b','c','d'])
dataset = dataset.shuffle(100)
for epoch in range(10):
    for d in dataset:
        print(d)

结果:

tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
...

似乎数据集并未为每个时期散布。我应该为每个时期打电话给.shuffle((?

是的,您应该在内部循环期间调用.shuffle。此外,最好不要混合python代码和TensorFlow代码。

import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c", "d"])
# dataset = dataset.shuffle(2)

@tf.function
def loop():
    for epoch in tf.range(10):
        for d in dataset.shuffle(2):
            tf.print(d)

loop()

循环调用每次都会产生不同的值(并且tf.print打印tf.Tensor的内容,与打印对象的print不同(。