构建tensorflow数据集迭代器,生成具有特殊结构的批



正如我在标题中提到的,我需要具有特殊结构的批次:

1111
5555
2222

每个数字表示特征向量。因此,每个类都有N=4向量{1,2,5}(M=3(,并且批大小为NxM=12

为了完成这个任务,我使用Tensorflow数据集API和tfrecords:

  • 使用特性构建tfrecord,每个类一个文件
  • 为每个类创建Dataset实例,并初始化每个类的迭代器
  • 从迭代器列表生成第I批样本M随机迭代器,并从每个迭代器生成N特征向量
  • 然后我将功能堆叠在一起
  • 批处理就绪

我担心的是,我有数百个(可能有数千个(类,并且为每个类存储迭代器看起来不太好(从内存和性能的角度来看(。

有更好的方法吗?

如果您有按类排序的文件列表,您可以交错数据集:

import tensorflow as tf
N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)
dataset = tf.data.Dataset.from_tensor_slices(record_files)
# Consider tf.contrib.data.parallel_interleave for parallelization
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=M, block_length=N)
# Consider passing num_parallel_calls or using tf.contrib.data.map_and_batch for performance
dataset = dataset.map(parse_function)
dataset = dataset.batch(N * M)

编辑:

如果你也需要洗牌,你可以在交织步骤中添加它:

import tensorflow as tf
N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)
SHUFFLE_BUFFER_SIZE = 1000
dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
lambda record_file: tf.data.TFRecordDataset(record_file).shuffle(SHUFFLE_BUFFER_SIZE),
cycle_length=M, block_length=N)
dataset = dataset.map(parse_function)
dataset = dataset.batch(N * M)

注:如果没有更多剩余元素,interleavebatch都将产生"部分"输出(请参阅文档(。因此,如果每一批都有相同的形状和结构对你来说很重要,你就必须特别小心。至于批处理,您可以使用tf.contrib.data.batch_and_drop_remainder,但据我所知,没有类似的交错选项,因此您必须确保所有文件都有相同数量的示例,或者只将repeat添加到交错转换中。

编辑2:

我得到了一个类似于我认为你想要的东西的概念证明:

import tensorflow as tf
NUM_EXAMPLES = 12
NUM_CLASSES = 9
records = [[str(i)] * NUM_EXAMPLES for i in range(NUM_CLASSES)]
M = 3
N = 4
dataset = tf.data.Dataset.from_tensor_slices(records)
dataset = dataset.interleave(tf.data.Dataset.from_tensor_slices,
cycle_length=NUM_CLASSES, block_length=N)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
dataset = dataset.flat_map(
lambda data: tf.data.Dataset.from_tensor_slices(
tf.split(tf.random_shuffle(
tf.reshape(data, (NUM_CLASSES, N))), NUM_CLASSES // M)))
dataset = dataset.map(lambda data: tf.reshape(data, (M * N,)))
batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
while True:
try:
b = sess.run(batch)
print(b''.join(b).decode())
except tf.errors.OutOfRangeError: break

输出:

888866663333
555544447777
222200001111
222288887777
666655553333
000044441111
888822225555
666600004444
777733331111

记录文件的等价物是这样的(假设记录是一维向量(:

import tensorflow as tf
NUM_CLASSES = 9
record_files = ['class{}.tfrecord'.format(i) for i in range(NUM_CLASSES)]
M = 3
N = 4
SHUFFLE_BUFFER_SIZE = 1000
dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
lambda file_name: tf.data.TFRecordDataset(file_name).shuffle(SHUFFLE_BUFFER_SIZE),
cycle_length=NUM_CLASSES, block_length=N)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
dataset = dataset.flat_map(
lambda data: tf.data.Dataset.from_tensor_slices(
tf.split(tf.random_shuffle(
tf.reshape(data, (NUM_CLASSES, N, -1))), NUM_CLASSES // M)))
dataset = dataset.map(lambda data: tf.reshape(data, (M * N, -1)))

这是通过每次读取每个类的N元素并对生成的块进行混洗和拆分来实现的。它假设类的数量可以被M整除,并且所有文件都有相同数量的记录。

最新更新