tf.data.Dataset.padded_batch以不同的方式填充每个功能



>我有一个包含 3 个不同功能的tf.data.Dataset实例

  • label这是一个标量
  • sequence_feature这是一个标量序列
  • seq_of_seqs_feature这是序列特征序列

我正在尝试使用tf.data.Dataset.padded_batch()生成填充数据作为模型的输入 - 我想以不同的方式填充每个功能。

示例批次:

[{'label': 24,
'sequence_feature': [1, 2],
'seq_of_seqs_feature': [[11.1, 22.2],
[33.3, 44.4]]},
{'label': 32,
'sequence_feature': [3, 4, 5],
'seq_of_seqs_feature': [[55.55, 66.66]]}]

预期产出:

[{'label': 24,
'sequence_feature': [1, 2, 0],
'seq_of_seqs_feature': [[11.1, 22.2],
[33.3, 44.4]]},
{'label': 32,
'sequence_feature': [3, 4, 5],
'seq_of_seqs_feature': [[55.55, 66.66],
0.0, 0.0    ]}]

如您所见,不应填充label功能,而应通过给定批次中相应的最长条目填充sequence_featureseq_of_seqs_feature

tf.data.Dataset.padded_batch()方法允许您为生成的批处理的每个组件(功能)指定padded_shapes。例如,如果输入数据集名为ds

padded_ds = ds.padded_batch(
BATCH_SIZE,
padded_shapes={
'label': [],                          # Scalar elements, no padding.
'sequence_feature': [None],           # Vector elements, padded to longest.
'seq_of_seqs_feature': [None, None],  # Matrix elements, padded to longest
})                                        # in each dimension.

请注意,padded_shapes参数与输入数据集的元素具有相同的结构,因此在本例中,它采用一个字典,其中包含与您的要素名称匹配的键。

最新更新