如何从Tensorflow中提取"图像"和"标签"?



我已经从CIFAR10加载了训练集和验证集,如下所示:

train = tfds.load('cifar10', split='train[:90%]', shuffle_files=True)
validation = tfds.load('cifar10', split='train[-10%:]', shuffle_files=True)

我已经为我的CNN创建了架构

model = ...

现在我试图使用model.fit()来训练我的模型,但我不知道如何从我的对象中分离出"图像"one_answers"标签"。训练和验证如下所示:

print(train) # same layout as the validation set
<_OptionsDataset shapes: {id: (), image: (32, 32, 3), label: ()}, types: {id: tf.string, image: tf.uint8, label: tf.int64}>

我天真的做法是这样的,但这些OptionsDatasets是不可下标的。

history = model.fit(train['image'], train['label'], epochs=100, batch_size=64, validation_data=(validation['image'], test['label'], verbose=0)

我们可以这样做

import tensorflow as tf
import tensorflow_datasets as tfds
def normalize(img, label):
img = tf.cast(img, tf.float32) / 255.
return (img, label)
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.shuffle(1024).batch(32).prefetch(tf.data.experimental.AUTOTUNE)
ds = ds.map(normalize)
for i in ds.take(1):
print(i[0].shape, i[1].shape)
# (32, 28, 28, 1) (32,)
  • 使用as_supervised=True返回image,label作为元组
  • 使用.map()进行预处理甚至增强。

模型

# declare input shape 
input = tf.keras.Input(shape=(28,28,1))
# Block 1
x = tf.keras.layers.Conv2D(32, 3, strides=2, activation="relu")(input)
# Now that we apply global max pooling.
gap = tf.keras.layers.GlobalMaxPooling2D()(x)
# Finally, we add a classification layer.
output = tf.keras.layers.Dense(10, activation='softmax')(gap)
# bind all
func_model = tf.keras.Model(input, output)

编译并运行

print('nFunctional API')
func_model.compile(
metrics=['accuracy'],
loss= 'sparse_categorical_crossentropy', # labels are integer (not one-hot)
optimizer = tf.keras.optimizers.Adam()
)
func_model.fit(ds)
# 1875/1875 [==============================] - 15s 7ms/step - loss: 2.1782 - accuracy: 0.2280

Tensorflow知道如何处理tfds对象。你可以直接输入

history = model.fit(train, epochs=100, batch_size=64, validation_data=(validation, verbose=0)

不需要从图像中分离出标签。但是如果你真的想,你可以这样做

labels = []
for image, label in tfds.as_numpy(train):
labels.append(label)

相关内容

  • 没有找到相关文章

最新更新