自定义 Tensorflow 数据集的类型规范



Background

我正在尝试为文本摘要构建一个转换器模型。我的数据集是CNN每日邮报,我使用TF数据集系列API来检索它们。

示例代码:

cnn_builder = tfds.summarization.cnn_dailymail.CnnDailymail()
cnn_info = cnn_builder.info
cnn_builder.download_and_prepare()
datasets = cnn_builder.as_dataset()
train_dataset, test_dataset = datasets["train"], datasets["test"]

问题

train_dataset的类型规格是这样的。如您所见,它就像一个字典,而我想像一个元组,这样我就可以更轻松地标记每个实例。

calling train_dataset.element_spec will return the element spec
{'article': TensorSpec(shape=(), dtype=tf.string, name=None),
'highlights': TensorSpec(shape=(), dtype=tf.string, name=None)}

所需的数据集类型规范

(TensorSpec(shape=(), dtype=tf.string, name=None),
TensorSpec(shape=(), dtype=tf.string, name=None))

实验和问题

我没有找到这样的API将数据集的元素转换为另一种形式,也无法检索每个元素并连接它们。有人有什么想法吗?或者,如果使用当前表单,如何迭代数据集中的每个实例?提前感谢!

可以使用tf.data.Dataset.map方法将字典转换为元组。

dset = train_dataset.map(lambda d: (d['article'], d['highlights']))

要对此进行测试,您可以获取每个数据集中的第一项,并比较字典和元组版本。

out_orig = next(iter(train_dataset))
out_mapped = next(iter(dset))
tf.assert_equal(out_orig['article'], out_mapped[0])
tf.assert_equal(out_orig['highlights'], out_mapped[1])

以下是元素规范:

>>> dset.element_spec
(TensorSpec(shape=(), dtype=tf.string, name=None),
TensorSpec(shape=(), dtype=tf.string, name=None))

相关内容

最新更新