Tensorflow 2.0:我可以更改Tf.data.Dataset上的设置吗 - 特别是"repeat()



我有一个从一些熊猫数据创建的张量流 2.0tf.data.Dataset。现在我想更改数据集上的设置,但它似乎不允许我这样做。例如,我想将数据集上的.repeat()参数从无限重复更改为仅重复 1 次。但是当我尝试进行此更改时,数据集不接受更改。

下面是一个包含一些代码的示例。该函数取自 TF 网站上的一个 Tensorflow 教程。

URL = 'https://storage.googleapis.com/applied-dl/heart.csv'
df = pd.read_csv(URL)
def df_to_dataset(dataframe, shuffle=True, batch_size=32):
dataframe = dataframe.copy()
labels = dataframe.pop('target')
ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
if shuffle:
ds = ds.shuffle(buffer_size=len(dataframe))
ds = ds.batch(batch_size).repeat() # <-- NOTICE THE INFINITE REPEAT
return ds
train_ds = df_to_dataset(df)
train_ds.repeat(1) # <-- TRYING TO CHANGE TO A FIXED NUMBER OF REPETITIONS

所以我尝试更改数据集上的重复次数,但这会导致数据集仍然永远重复。这就像如果我将数据集设置为无限重复,然后重复 1 次,我最终得到一个无限重复但 1 次的数据集——这与无限重复相同。

我想同样的行为可能也适用于数据集的其他特征,例如批次数等。

有没有办法重置数据集上的行为?

我认为这是预期的行为。一旦它是一个无限重复的数据集,再重复一次仍然会产生一个有无限例子的数据集。

您可能可以做dataset.take(count),其中count是等于原始示例的批次数的数字,前提是您对数据进行了很好的洗牌。

相关内容

  • 没有找到相关文章