repeat()在创建tf.data.Dataset对象时有什么用途



我正在重现TensorFlow的时间序列预测教程的代码。

他们使用tf.data对数据集进行混洗、批处理和缓存。更确切地说,他们做以下事情:

BATCH_SIZE = 256
BUFFER_SIZE = 10000
train_univariate = tf.data.Dataset.from_tensor_slices((x_train_uni, y_train_uni))
train_univariate = train_univariate.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
val_univariate = tf.data.Dataset.from_tensor_slices((x_val_uni, y_val_uni))
val_univariate = val_univariate.batch(BATCH_SIZE).repeat()

我不明白他们为什么使用repeat(),更不明白为什么不指定repeat的count参数。让这个过程无限期地重复有什么意义?算法如何读取无限大数据集中的所有元素?

在tensorflow联邦图像分类教程中可以看到,重复方法用于使用数据集的重复,这也将指示训练的时期数

因此,使用.reeat(NUM_EPOCHS(,其中NUM_EPOCHS是训练的历元数。

最新更新