TensorFlow数据集数据预处理用于整个数据集或每次拨打iterator.next()的调用



嗨,我现在正在研究TensorFlow中的数据集API,我对DataSet.map((函数有一个疑问,该函数执行数据预处理。

file_name = ["image1.jpg", "image2.jpg", ......]
im_dataset = tf.data.Dataset.from_tensor_slices(file_names)
im_dataset = im_dataset.map(lambda image:tuple(tf.py_func(image_parser(), [image], [tf.float32, tf.float32, tf.float32])))
im_dataset = im_dataset.batch(batch_size)
iterator = im_dataset.make_initializable_iterator()

数据集接收图像名称并将其解析为3个张量(3个有关图像的信息(。

如果我的培训文件夹中有大量的图像,那么对它们进行预处理将需要很长时间。我的问题是,由于据说数据集API是为了有效的输入管道而设计的,因此在我向工人喂给工人之前,对整个数据集进行了预处理(例如,GPU(,或者每次我每次仅预处理一批图像调用iterator.get_next((?

如果您的预处理管道非常长并且输出很小,则处理后的数据应适合内存。如果是这种情况,您可以使用tf.data.Dataset.cache在内存或文件中缓存处理的数据。

摘自官方绩效指南:

tf.data.Dataset.cache转换可以在内存或本地存储中缓存数据集。如果将用户定义的函数传递到地图变换很昂贵,则只要所得数据集仍然适合内存或本地存储,就将地图转换后的缓存转换应用。如果用户定义的功能增加了存储数据集超过缓存容量所需的空间,请考虑在培训工作之前对数据进行预处理以减少资源使用。


在内存中使用缓存的示例

这是一个示例,每个预处理需要大量时间(0.5s(。数据集上的第二个时期将比第一个时代快得多

def my_fn(x):
    time.sleep(0.5)
    return x
def parse_fn(x):
    return tf.py_func(my_fn, [x], tf.int64)
dataset = tf.data.Dataset.range(5)
dataset = dataset.map(parse_fn)
dataset = dataset.cache()    # cache the processed dataset, so every input will be processed once
dataset = dataset.repeat(2)  # repeat for multiple epochs
res = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for i in range(10):
        # First 5 iterations will take 0.5s each, last 5 will not
        print(sess.run(res))

缓存到文件

如果要将缓存的数据写入文件,则可以向cache()提供参数:

dataset = dataset.cache('/tmp/cache')  # will write cached data to a file

这将允许您仅处理一次数据集,并在数据上运行多个实验而不再次重新处理。

警告:在缓存到文件时,您必须小心。如果更改数据,但请保留/tmp/cache.*文件,它仍然会读取缓存的旧数据。例如,如果我们使用上面的数据并更改[10, 15]中的数据范围,我们仍将在[0, 5]中获取数据:

dataset = tf.data.Dataset.range(10, 15)
dataset = dataset.map(parse_fn)
dataset = dataset.cache('/tmp/cache')
dataset = dataset.repeat(2)  # repeat for multiple epochs
res = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for i in range(10):
        print(sess.run(res))  # will still be in [0, 5]...

始终 delete 每当您要缓存的数据更改时,缓存的文件。

可能出现的另一个问题是,如果您在所有数据被缓存之前中断脚本。您将收到这样的错误:

AmandExistSistError(有关追溯性,请参见上文(:似乎有一个并发的缓存迭代器运行-CACHE LOCKFILE已经存在('/tmp/cache.lockfile'(。如果您确定没有其他运行的TF计算使用此缓存前缀,请删除锁定文件并重新定位迭代器。

确保您让整个数据集处理整个缓存文件。

最新更新