我有一个输入数据集(让我们将其命名为ds),一个传递给编码器(命名为embedder
的模型)的函数。我想做一个编码的数据集,并保存到文件。我想做的:
转换器功能:
def generate_embedding(image, label, embedder):
return (embedder(image)[0], label)
转换:
embedding_ds = ds.map(lambda image, label: generate_embedding(image, label, embedder), num_parallel_calls=tf.data.AUTOTUNE)
拯救:
embedding_ds.save(path)
但我有embedding_ds
的问题,它不是tf.data.Dataset
(这是我所期望的),但tf.raw_ops.ParallelMapDataset
,它没有保存方法。谁能给点建议?
看起来这个问题存在于我的tensorflow版本(2.9.2),而不存在于2.11
可能更新?在2.11.0中,它可以工作:
import tensorflow as tf
ds = tf.data.Dataset.range(5)
tf.__version__ # 2.11.0
ds = ds.map(lambda e : (e + 3) % 5, num_parallel_calls=3)
ds.save('test') # works