TensorFlow在迭代数据集时显示进度



我有一个非常大的数据集(原始文件~750GB),我使用TensorFlow数据API创建了一个缓存数据集管道,如下所示:

dataset = tf.data.Dataset.from_generator(MSUMFSD(pathlib.Path(dataset_locations["mfsd"]), True), output_types=(tf.string, tf.float32))

这个数据集包含我想要用于处理的所有文件路径。之后,我使用这个interleave转换,为我的模型生成实际的输入数据:

class DatasetTransformer:
def __init__(self, haar_cascade_path, window, img_shape):
self.rppg_extractor = RPPGExtractionChrom(haar_cascade_path, window, img_shape)
self.window = window
self.img_shape = img_shape
def generator(self, file, label):
for signal, frame in self.rppg_extractor.process_file_iter(file.decode()):
yield (signal, frame), [label]
def __call__(self, file, label):
output_signature = (
(
tensorflow.TensorSpec(shape=(self.window), dtype=tensorflow.float32),
tensorflow.TensorSpec(shape=(self.img_shape[0], self.img_shape[1], 3), dtype=tensorflow.float32)
),
tensorflow.TensorSpec(shape=(1), dtype=tensorflow.float32))
return tensorflow.data.Dataset.from_generator(self.generator, args=(file, label), output_signature=output_signature)

dataset = dataset.interleave(
DatasetTransformer("rppg/haarcascade_frontalface_default.xml", window_size, img_shape),
num_parallel_calls=tf.data.AUTOTUNE
)
dataset = dataset.prefetch(tf.data.AUTOTUNE).shuffle(320).cache(cache_filename)

现在我想对数据集进行一次迭代,以创建缓存的数据集(由模型的实际输入组成)并获得数据集大小。是否有一种方法可以显示迭代的进度?我的尝试是在像这样的交错转换之前获得文件的数量:

dataset_file_amount = dataset.reduce(0, lambda x,_: x + 1).numpy()

,然后显示一个使用TQDM的进度条,同时迭代遍历像这样的数据集:

def dataset_reducer(x, pbar):
pbar.update()
return x + 1
pbar = tqdm(total=dataset_file_amount, desc="Preprocessing files...")
size = dataset.reduce(0, lambda x,_: dataset_reducer(x, pbar)).numpy()

当运行这段代码时,我得到一个包含正确总数(文件数)的进度条,但进度条没有更新。一旦处理完成,它就会停留在0%,它只是继续执行。您知道如何显示(至少对于已处理的文件的数量)预处理的进度吗?谢谢了!

编辑实际上,进度条卡在1/X而不是0%。

我通过不更新reduce函数内部的进度条修复了这个问题。我将pbar对象传递给DatasetTransformer类,并在生成方法的for循环之后更新进度。这将基于已处理的文件更新进度(我每个文件提取几百帧,现在我得到了已经处理了多少文件的进度):

class DatasetTransformer:
def __init__(self, haar_cascade_path, window, img_shape, progress):
self.rppg_extractor = RPPGExtractionChrom(haar_cascade_path, window, img_shape)
self.window = window
self.img_shape = img_shape
self.progress = progress
def generator(self, file, label):
rppg_extractor = RPPGExtractionChrom(self.haar_cascade_path, self.window, self.img_shape)
for signal, frame in rppg_extractor.process_file_iter(file.decode()):
yield (signal, frame), [label]
self.progress.update(1) # <- Update Progress bar here
def __call__(self, file, label):
output_signature = (
(
tensorflow.TensorSpec(shape=(self.window), dtype=tensorflow.float32),
tensorflow.TensorSpec(shape=(self.img_shape[0], self.img_shape[1], 3), dtype=tensorflow.float32)
),
tensorflow.TensorSpec(shape=(1), dtype=tensorflow.float32))
return tensorflow.data.Dataset.from_generator(self.generator, args=(file, label), output_signature=output_signature)
dataset = dataset.interleave(
DatasetTransformer("rppg/haarcascade_frontalface_default.xml", window_size, img_shape),
num_parallel_calls=tf.data.AUTOTUNE
)
pbar = tqdm(total=dataset_file_amount, desc="Preprocessing files...")
dataset = dataset.interleave(
DatasetTransformer("rppg/haarcascade_frontalface_default.xml", window_size, img_shape, pbar),
num_parallel_calls=tf.data.AUTOTUNE
)

相关内容

  • 没有找到相关文章

最新更新