如何填充到tf.data.Dataset中的固定BATCH_SIZE



我有一个包含 11 个样本的数据集。当我选择BATCH_SIZE为 2 时,以下代码将出现错误:

dataset = tf.contrib.data.TFRecordDataset(filenames) 
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)

问题在于dataset = dataset.batch(batch_size),当Dataset循环到最后一批时,剩余的样本计数仅为 1,那么有没有办法从之前访问的样本中随机选择一个并生成最后一批?

@mining通过填充文件名来提出解决方案。

另一种解决方案是使用 tf.contrib.data.batch_and_drop_remainder .这将以固定的批大小对数据进行批处理,并删除最后一个较小的批处理。

在您的示例中,如果输入为 11 个,批量大小为 2,这将产生 5 个批次,每批 2 个元素。

以下是文档中的示例:

dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))
您可以在

调用batch中设置drop_remainder=True

dataset = dataset.batch(batch_size, drop_remainder=True)

从文档中:

drop_remainder:(可选。A tf.bool scalar tf.张量,表示 在它有较少的情况下是否应该删除最后一批 比batch_size元素;默认行为是不删除 小批量。

最新更新