我有一个从一些熊猫数据创建的张量流 2.0tf.data.Dataset
。现在我想更改数据集上的设置,但它似乎不允许我这样做。例如,我想将数据集上的.repeat()
参数从无限重复更改为仅重复 1 次。但是当我尝试进行此更改时,数据集不接受更改。
下面是一个包含一些代码的示例。该函数取自 TF 网站上的一个 Tensorflow 教程。
URL = 'https://storage.googleapis.com/applied-dl/heart.csv'
df = pd.read_csv(URL)
def df_to_dataset(dataframe, shuffle=True, batch_size=32):
dataframe = dataframe.copy()
labels = dataframe.pop('target')
ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
if shuffle:
ds = ds.shuffle(buffer_size=len(dataframe))
ds = ds.batch(batch_size).repeat() # <-- NOTICE THE INFINITE REPEAT
return ds
train_ds = df_to_dataset(df)
train_ds.repeat(1) # <-- TRYING TO CHANGE TO A FIXED NUMBER OF REPETITIONS
所以我尝试更改数据集上的重复次数,但这会导致数据集仍然永远重复。这就像如果我将数据集设置为无限重复,然后重复 1 次,我最终得到一个无限重复但 1 次的数据集——这与无限重复相同。
我想同样的行为可能也适用于数据集的其他特征,例如批次数等。
有没有办法重置数据集上的行为?
我认为这是预期的行为。一旦它是一个无限重复的数据集,再重复一次仍然会产生一个有无限例子的数据集。
您可能可以做dataset.take(count)
,其中count
是等于原始示例的批次数的数字,前提是您对数据进行了很好的洗牌。