我有一个包含 11 个样本的数据集。当我选择BATCH_SIZE
为 2 时,以下代码将出现错误:
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parser)
if shuffle:
dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)
问题在于dataset = dataset.batch(batch_size)
,当Dataset
循环到最后一批时,剩余的样本计数仅为 1,那么有没有办法从之前访问的样本中随机选择一个并生成最后一批?
@mining通过填充文件名来提出解决方案。
另一种解决方案是使用 tf.contrib.data.batch_and_drop_remainder
.这将以固定的批大小对数据进行批处理,并删除最后一个较小的批处理。
在您的示例中,如果输入为 11 个,批量大小为 2,这将产生 5 个批次,每批 2 个元素。
以下是文档中的示例:
dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))
调用batch
中设置drop_remainder=True
。
dataset = dataset.batch(batch_size, drop_remainder=True)
从文档中:
drop_remainder:(可选。A tf.bool scalar tf.张量,表示 在它有较少的情况下是否应该删除最后一批 比batch_size元素;默认行为是不删除 小批量。