我有一组序列化TensorFlow Example协议缓冲区的TFRecord文件,每个注释有一个Example proto,下载自https://magenta.tensorflow.org/datasets/nsynth.我使用大约1 Gb的测试集来检查下面的代码,以防有人想下载它。每个示例都包含许多功能:音高、乐器。。。
读取该数据的代码是:
import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
# Reading input data
dataset = tf.data.TFRecordDataset('../data/nsynth-test.tfrecord')
# Convert features into tensors
features = {
"pitch": tf.FixedLenFeature([1], dtype=tf.int64),
"audio": tf.FixedLenFeature([64000], dtype=tf.float32),
"instrument_family": tf.FixedLenFeature([1], dtype=tf.int64)}
parse_function = lambda example_proto: tf.parse_single_example(example_proto,features)
dataset = dataset.map(parse_function)
# Consuming TFRecord data.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size=3)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess.run(batch)
现在,音高在21到108之间。但我只想考虑给定音高的数据,例如音高=51。如何从整个数据集中提取这个"pitch=51"子集?或者,我该怎么做才能使迭代器只通过这个子集?
您所拥有的看起来相当不错,所缺少的只是一个过滤函数。
例如,如果你只想提取间距=51,你应该在地图功能之后添加
dataset = dataset.filter(lambda example: tf.equal(example["pitch"][0], 51))