如何从张量流数据集迭代器返回相同的批次两次



我正在转换一些遗留代码以使用数据集 API - 此代码使用 feed_dict 将一个批次馈送到训练操作(实际上是三次(,然后重新计算损失以使用相同的批次显示。所以我需要一个迭代器,它返回完全相同的批次两次(或几次(。不幸的是,我似乎找不到使用张量流数据集的方法 - 这可能吗?

您可以使用Dataset.flat_map()Dataset.from_tensors()Dataset.repeat()一起重复Dataset的各个元素。例如,要重复元素两次:

NUM_REPEATS = 2
dataset = tf.data.Dataset.range(10)  # ...or the output of `.batch()`, etc.
# Repeat each element of `dataset` NUM_REPEATS times.
dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(NUM_REPEATS))

最新更新