预测tensorflow与repeat的使用混淆



我偶然发现了这本关于预测的笔记本。我是通过这篇文章得到它的。

我对

下面的第2行和第4行感到困惑
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.cache().shuffle(buffer_size).batch(batch_size).repeat()
val_data = tf.data.Dataset.from_tensor_slices((x_vali, y_vali))
val_data = val_data.batch(batch_size).repeat()

我理解我们正在尝试洗牌我们的数据,因为我们不想以串行顺序向我们的模型提供数据。在额外的阅读中,我意识到最好让buffer_size与数据集的大小相同。但我不确定repeat在这种情况下做什么。谁能解释一下这里正在做什么,repeat的作用是什么?

我也看了这一页,看到下面的文字,但仍然不清楚。

The following methods in tf.Dataset :
repeat( count=0 ) The method repeats the dataset count number of times.
shuffle( buffer_size, seed=None, reshuffle_each_iteration=None) The method shuffles the samples in the dataset. The buffer_size is the number of samples which are randomized and returned as tf.Dataset.
batch(batch_size,drop_remainder=False) Creates batches of the dataset with batch size given as batch_size which is also the length of the batches.

不传递任何参数给count的repeat调用使这个数据集无限重复。

在python术语中,数据集是python可迭代对象的子类。如果您有一个tf.data.Dataset类型的对象ds,那么您可以执行iter(ds)。如果数据集是由repeat()生成的,那么它永远不会耗尽项目,也就是说,它永远不会抛出StopIteration异常。

在您引用的笔记本中,将对tf.keras.Model.fit()的调用传递给参数steps_per_epoch的参数100。这意味着数据集应该无限重复,Keras将暂停训练以每100步运行一次验证。

tldr: leave it in.

https://github.com/tensorflow/tensorflow/blob/3f878cff5b698b82eea85db2b60d65a2e320850e/tensorflow/python/data/ops/dataset_ops.py L134-L3445

https://docs.python.org/3/library/exceptions.html

相关内容

  • 没有找到相关文章

最新更新