如何在Tensorflow中遍历TensorSliceDataset对象



我使用CIFAR-100数据集创建了tensorflowDataset对象。我需要访问Dataset对象内的列车标签TensorSpec。由于TensorSliceDataset对象不支持索引,因此无法通过索引进行访问。如何访问每个TensorSpec并对其中的值进行迭代。

(train_data, train_labels), (test_data, test_labels) = cifar100.load_data(label_mode='fine')
with open('data/cifar100/cifar100_labels.json', 'r') as j:
cifar_labels = json.load(j)
dataset = tf.data.Dataset.from_tensor_slices((train_data,train_labels))
print(train_dataset.element_spec)
# (TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 
# TensorSpec(shape=(1,), dtype=tf.int64, name=None))

您可以将标签转换为一个数组:

import tensorflow as tf
(train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices((train_data,train_labels))
next(dataset.batch(60_000).as_numpy_iterator())[1]
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

这就是你要找的吗?

最新更新