如何使用 tf.data.Dataset 加载并映射字典(JSON)列表



我有一个数据集,其中的记录以字典列表的形式存储,这些字典的结构可能相当复杂。我想通过 TensorFlow 的 tf.data.Dataset API 加载这个列表,应该怎么做?我尝试了类似下面的写法,但它不起作用:

import tensorflow as tf
import json
# Label vocabulary; parse_record uses each name's position in this list as its class index.
LABELS_IDS = ["cat", "dog", "animal"]
def parse_record(record):
    """Load one record's image and build its multi-hot label vector.

    Args:
        record: Mapping with a "_file" entry (path of a JPEG image) and a
            "_categories" entry (iterable of category names).

    Returns:
        A tuple ``(image, one_hot_labels)``: the decoded image converted to
        float32, resized to 224x224 and randomly flipped left/right, and a
        vector of length ``len(LABELS_IDS)`` obtained by summing the one-hot
        rows of the categories found in ``LABELS_IDS``.

    NOTE(review): the Python-level loop below presumably needs plain Python
    strings in ``record["_categories"]``; the accompanying question text
    reports that mapping this function over a dataset built directly from
    the list of dicts does not work.
    """
    image = tf.io.read_file(record["_file"])
    image = tf.image.decode_jpeg(image)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [224, 224])
    image = tf.image.random_flip_left_right(image, seed=None)

    # Categories that are not listed in LABELS_IDS are skipped.
    labels = []
    for element in record["_categories"]:
        if element in LABELS_IDS:
            labels.append(LABELS_IDS.index(element))

    # Summing the one-hot rows merges all matched categories into one vector.
    one_hot_labels = tf.reduce_sum(tf.one_hot(labels, len(LABELS_IDS)), axis=0)
    return image, one_hot_labels
# Example input: a list with one record dict (image path plus category names).
records = [{"_file":"images/test.jpg", "_categories": ["cat", "animal"]}]

# Attempt to slice the list of dicts directly and map parse_record over it.
# NOTE(review): the accompanying question text reports that this does not work.
train_x = tf.data.Dataset.from_tensor_slices(records).map(parse_record)

编辑:

我找到了答案:可以把记录拆开,分别交给不同的函数处理(一个提取文件路径,一个生成标签向量),再把结果一起传给数据集:

# Label vocabulary; a category's position in this list is used as its class index.
LABELS_IDS = ["cat", "dog", "animal"]
# Example input: one record dict holding an image path and its category names.
records = [{"_file":"images/test.jpg", "_categories": ["cat", "animal"]}]
def _load_files(records):
return [record["_file"] for record in records]
def _load_labels(records):
    """Build one multi-hot label vector per record.

    Args:
        records: Iterable of dicts that each contain a "_categories"
            iterable of category names.

    Returns:
        A list with one NumPy vector of length ``len(LABELS_IDS)`` per
        record. Categories that are not listed in ``LABELS_IDS`` are ignored.
    """
    vectors = []
    for record in records:
        labels = [
            LABELS_IDS.index(element)
            for element in record["_categories"]
            if element in LABELS_IDS
        ]
        # Explicit integer dtype: an empty list would otherwise be inferred
        # as a float tensor, which is not a valid index type for tf.one_hot.
        indices = tf.constant(labels, dtype=tf.int32)
        # Summing the one-hot rows merges all matched categories into one vector.
        one_hot = tf.reduce_sum(tf.one_hot(indices, len(LABELS_IDS)), axis=0)
        vectors.append(one_hot.numpy())
    return vectors
def _load_data(file_path, label):
    """Read and decode the image at ``file_path``; pass ``label`` through.

    Args:
        file_path: Path of the image file to read.
        label: Label associated with the image; returned unchanged.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the result of
        ``tf.image.decode_image`` with three channels requested.
    """
    image = tf.io.read_file(file_path)
    # expand_animations=False: see the tf.io.decode_image documentation for
    # how animated inputs are handled with this flag.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    return image, label
# Two parallel lists (file paths, label vectors) with one entry per record.
data = (
_load_files(records),
_load_labels(records)
)
# Slice the tuple so each dataset element is a (file_path, label) pair,
# then map _load_data over those pairs.
train_x = tf.data.Dataset.from_tensor_slices(data).map(_load_data)

为了方便社区参考,我在这里附上 @Cospel 的答案:

# Label vocabulary; a category's position in this list is used as its class index.
LABELS_IDS = ["cat", "dog", "animal"]
# Example input: one record dict holding an image path and its category names.
records = [{"_file":"images/test.jpg", "_categories": ["cat", "animal"]}]
def _load_files(records):
return [record["_file"] for record in records]
def _load_labels(records):
    """Build one multi-hot label vector per record.

    Args:
        records: Iterable of dicts that each contain a "_categories"
            iterable of category names.

    Returns:
        A list with one NumPy vector of length ``len(LABELS_IDS)`` per
        record. Categories that are not listed in ``LABELS_IDS`` are ignored.
    """
    vectors = []
    for record in records:
        labels = [
            LABELS_IDS.index(element)
            for element in record["_categories"]
            if element in LABELS_IDS
        ]
        # Explicit integer dtype: an empty list would otherwise be inferred
        # as a float tensor, which is not a valid index type for tf.one_hot.
        indices = tf.constant(labels, dtype=tf.int32)
        # Summing the one-hot rows merges all matched categories into one vector.
        one_hot = tf.reduce_sum(tf.one_hot(indices, len(LABELS_IDS)), axis=0)
        vectors.append(one_hot.numpy())
    return vectors
def _load_data(file_path, label):
    """Read and decode the image at ``file_path``; pass ``label`` through.

    Args:
        file_path: Path of the image file to read.
        label: Label associated with the image; returned unchanged.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the result of
        ``tf.image.decode_image`` with three channels requested.
    """
    image = tf.io.read_file(file_path)
    # expand_animations=False: see the tf.io.decode_image documentation for
    # how animated inputs are handled with this flag.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    return image, label
# Two parallel lists (file paths, label vectors) with one entry per record.
data = (
_load_files(records),
_load_labels(records)
)
# Slice the tuple so each dataset element is a (file_path, label) pair,
# then map _load_data over those pairs.
train_x = tf.data.Dataset.from_tensor_slices(data).map(_load_data)

最新更新