Tensorflow 2.0:如何从MapDataset(在阅读TFRecord之后)转换为可以输入到model.fit



我将训练和验证数据存储在两个独立的TFRecord文件中,其中存储了4个值:信号A(float32 shape(150,((、信号B(float32shape(150.((、标签(标量int64(、id(字符串(。我的阅读解析功能是:

def _parse_data_function(sample_proto):
raw_signal_description = {
'label': tf.io.FixedLenFeature([], tf.int64),
'id': tf.io.FixedLenFeature([], tf.string),
}
for key, item in SIGNALS.items():
raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)
# Parse the input tf.Example proto using the dictionary above.
return tf.io.parse_single_example(sample_proto, raw_signal_description)

其中CCD_ 1是映射信号名称->的信号形状的字典。然后,我读取了原始数据集:

training_raw = tf.data.TFRecordDataset(<path to training>), compression_type='GZIP')
val_raw = tf.data.TFRecordDataset(<path to validation>), compression_type='GZIP')

并使用map解析值:

training_data = training_raw.map(_parse_data_function)
val_data = val_raw.map(_parse_data_function)

显示training_dataval_data的标题,我得到:

<MapDataset shapes: {Signal A: (150,), Signal B: (150,), id: (), label: ()}, types: {Signal A: tf.float32, Signal B: tf.float32, id: tf.string, label: tf.int64}>

这几乎和预期的一样。我还检查了一些值的一致性,它们似乎是正确的。

现在,我的问题是:如何从具有类似字典结构的MapDataset中获得可以作为模型输入的内容

我的模型的输入是成对的(信号A,标签(,尽管将来我也会使用信号B。

对我来说,最简单的方法似乎是在我想要的元素上创建一个生成器。类似于:

def data_generator(mapdataset):
for sample in mapdataset:
yield (sample['Signal A'], sample['label'])

然而,使用这种方法,我失去了数据集的一些便利性,例如批处理,并且也不清楚如何对model.fitvalidation_data参数使用相同的方法。理想情况下,我只会在映射表示和数据集表示之间转换,其中它在信号A张量和标签对上迭代。

编辑:我的最终产品应该是类似于以下标题的东西:<TensorSliceDataset shapes: ((150,), ()), types: (tf.float32, tf.int64)>但不一定是TensorSliceDataset

您可以在解析函数中简单地执行此操作。例如:

def _parse_data_function(sample_proto):
raw_signal_description = {
'label': tf.io.FixedLenFeature([], tf.int64),
'id': tf.io.FixedLenFeature([], tf.string),
}
for key, item in SIGNALS.items():
raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)
# Parse the input tf.Example proto using the dictionary above.
parsed = tf.io.parse_single_example(sample_proto, raw_signal_description)
return parsed['Signal A'], parsed['label']

如果您在SIGNALS0上map此函数,您将拥有元组(signal_a, label)的数据集,而不是字典的数据集。您应该能够将其直接放入model.fit

最新更新