如何在tf 2.1.0中创建tf.data.Dataset的训练,测试和验证拆分



以下代码是从中复制的:https://www.tensorflow.org/tutorials/load_data/images

该代码旨在创建从网上下载的图像数据集,并根据它们的类存储到文件夹中,请参考上面的链接了解整个上下文!

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
for f in list_ds.take(5):
print(f.numpy())
def get_label(file_path):
# convert the path to a list of path components
parts = tf.strings.split(file_path, os.path.sep)
# The second to last is the class-directory
return parts[-2] == CLASS_NAMES
def decode_img(img):
# convert the compressed string to a 3D uint8 tensor
img = tf.image.decode_jpeg(img, channels=3)
# Use `convert_image_dtype` to convert to floats in the [0,1] range.
img = tf.image.convert_image_dtype(img, tf.float32)
# resize the image to the desired size.
return tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])
def process_path(file_path):
label = get_label(file_path)
# load the raw data from the file as a string
img = tf.io.read_file(file_path)
img = decode_img(img)
return img, label
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
for image, label in labeled_ds.take(1):
print("Image shape: ", image.numpy().shape)
print("Label: ", label.numpy())
def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
# This is a small dataset, only load it once, and keep it in memory.
# use `.cache(filename)` to cache preprocessing work for datasets that don't
# fit in memory.
if cache:
if isinstance(cache, str):
ds = ds.cache(cache)
else:
ds = ds.cache()
ds = ds.shuffle(buffer_size=shuffle_buffer_size)
# Repeat forever
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
# `prefetch` lets the dataset fetch batches in the background while the model
# is training.
ds = ds.prefetch(buffer_size=AUTOTUNE)
return ds
train_ds = prepare_for_training(labeled_ds)

我们最后剩下的是train_ds,它是PreffetchDataset对象,包含图像、标签的整个数据集!如何将train_ds分解为训练、测试和训练;验证集以将其输入到模型中?

ds.repeat()调用之后,数据集是无限的,拆分中篇数据集的效果不太好。因此,您应该在prepare_training()调用之前将其拆分。像这样:

labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
labeled_ds = labeled_ds.shuffle(10000).batch(BATCH_SIZE)
# Size of dataset
n = sum(1 for _ in labeled_ds)
n_train = int(n * 0.8)
n_valid = int(n * 0.1)
n_test = n - n_train - n_valid
train_ds = labeled_ds.take(n_train)
valid_ds = labeled_ds.skip(n_train).take(n_valid)
test_ds = labeled_ds.skip(n_train + n_valid).take(n_test)

n = sum(1 for _ in labeled_ds)行在数据集中迭代一次以获得其大小,然后将其分为80%/10%/10%。

最新更新