如何在 tensorflow/keras 中从下载的 tar.gz 文件中加载数据?



Tensorflow 数据集或 tfds 会自动开始下载我想要的数据。我的系统中下载了 cifar10。我可以使用以下命令直接在 pytorch 中加载数据: torchvision.datasets.CIFAR10('path/to/directory',...,download=False(

是否有与此等效的张量流或 keras?

我认为你能做的最好的事情就是首先提取tar文件:

import tarfile
if fname.endswith("tar.gz"):
tar = tarfile.open(fname, "r:gz")
tar.extractall()
tar.close()
elif fname.endswith("tar"):
tar = tarfile.open(fname, "r:")
tar.extractall()
tar.close()

然后访问模型数据并使用 keras 加载它:

https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model

发布另一种从本地加载文件的方法,供以后找到它的人使用。

在 URL 参数中,提供文件的本地 URL 和路径。例如,我的 D 驱动器中文件夹 Workspace/DataFiles/tldr.gz 中有一个文件,那么我为 URL 参数给出的值将是这样的。

path = 'file:///D:/Workspace/DataFiles/tldr.gz'
path_to_downloaded_file = tf.keras.utils.get_file("tldr_data",path, archive_format='tar', untar=True)`

通过这种方式,keras 识别 URL 并从文件中加载数据。

如何使用 tf.keras.utils.get_file(( 函数获取使用HTTP(S( 下载的文件的提取tar.gz的文件名/目录:

import os
from tensorflow import keras
# `fname` in the following function is an equivalent to
# `wget`'s `-O` parameter. It is needed so that the download
# function know under which name to save the downloaded file.
file_path = keras.utils.get_file(fname='ucf101_top5.tar.gz',
origin='https://zenodo.org/record/7924745/files/ucf101_top5.tar.gz',
extract=True)
keras_datasets_dir = os.path.dirname(file_path)
# Assuming that you know of the contents of the archive to be extracted:
dataset_dir = os.path.join(keras_datasets_dir, 'ucf101_top5')
print(dataset_dir)

示例输出:

'/Users/mbuch/.keras/datasets/ucf101_top5'

当然,您必须知道要下载的文件的提取目录/内容的名称才能使用此示例。

最新更新