如何保存ParallelMapDataset?



我有一个输入数据集(让我们将其命名为ds),一个传递给编码器(命名为embedder的模型)的函数。我想做一个编码的数据集,并保存到文件。我想做的:

转换器功能:

def generate_embedding(image, label, embedder):
return (embedder(image)[0], label)
转换:

embedding_ds = ds.map(lambda image, label: generate_embedding(image, label, embedder), num_parallel_calls=tf.data.AUTOTUNE)

拯救:

embedding_ds.save(path)

但我有embedding_ds的问题,它不是tf.data.Dataset(这是我所期望的),但tf.raw_ops.ParallelMapDataset,它没有保存方法。谁能给点建议?


看起来这个问题存在于我的tensorflow版本(2.9.2),而不存在于2.11

可能更新?在2.11.0中,它可以工作:

import tensorflow as tf
ds = tf.data.Dataset.range(5)
tf.__version__ # 2.11.0
ds = ds.map(lambda e : (e + 3) % 5, num_parallel_calls=3)
ds.save('test') # works

相关内容

  • 没有找到相关文章

最新更新