我偶然发现了这本关于预测的笔记本。我是通过这篇文章得到它的。
我对
下面的第2行和第4行感到困惑train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.cache().shuffle(buffer_size).batch(batch_size).repeat()
val_data = tf.data.Dataset.from_tensor_slices((x_vali, y_vali))
val_data = val_data.batch(batch_size).repeat()
我理解我们正在尝试洗牌我们的数据,因为我们不想以串行顺序向我们的模型提供数据。在额外的阅读中,我意识到最好让buffer_size
与数据集的大小相同。但我不确定repeat
在这种情况下做什么。谁能解释一下这里正在做什么,repeat
的作用是什么?
我也看了这一页,看到下面的文字,但仍然不清楚。
The following methods in tf.Dataset :
repeat( count=0 ) The method repeats the dataset count number of times.
shuffle( buffer_size, seed=None, reshuffle_each_iteration=None) The method shuffles the samples in the dataset. The buffer_size is the number of samples which are randomized and returned as tf.Dataset.
batch(batch_size,drop_remainder=False) Creates batches of the dataset with batch size given as batch_size which is also the length of the batches.
不传递任何参数给count的repeat调用使这个数据集无限重复。
在python术语中,数据集是python可迭代对象的子类。如果您有一个tf.data.Dataset
类型的对象ds
,那么您可以执行iter(ds)
。如果数据集是由repeat()
生成的,那么它永远不会耗尽项目,也就是说,它永远不会抛出StopIteration
异常。
在您引用的笔记本中,将对tf.keras.Model.fit()的调用传递给参数steps_per_epoch
的参数100
。这意味着数据集应该无限重复,Keras将暂停训练以每100步运行一次验证。
tldr: leave it in.
https://github.com/tensorflow/tensorflow/blob/3f878cff5b698b82eea85db2b60d65a2e320850e/tensorflow/python/data/ops/dataset_ops.py L134-L3445
https://docs.python.org/3/library/exceptions.html