Failed to find data adapter that can handle input: <class 'tensorflow.python.keras.preprocessing.image.ImageDataGenerator'>, <class 'N



I tried data augmentation for a CNN, but I get the error "Failed to find data adapter that can handle input: <class 'tensorflow.python.>, <class 'NoneType'>". Can anyone help me?

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
(x_train, y_train) , (x_test, y_test) = datasets.cifar10.load_data()
img_width, img_height, img_num_channels = 32, 32, 3
x_train = x_train.astype('float32')       
x_test = x_test.astype('float32')
x_train /= 255.0              
x_test /= 255.0
input_shape = (img_width, img_height, img_num_channels)
# CNN used (the model definition is not included in the question)
train_datagen = ImageDataGenerator(rescale=1./255, rotation_range=45, shear_range=0.2, 
                                   zoom_range=0.2, horizontal_flip=True)
validation_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow(x_train, y_train, batch_size=100)
batch_size=100
model.fit_generator(train_generator,
                    steps_per_epoch = 2000 //batch_size, epochs =2,
                    validation_data = validation_datagen,
                    validation_steps = 800 //batch_size)
score = model.evaluate(x_test, y_test) 
print('Test loss:', score[0]) 
print('Test accuracy:', score[1])
predictions = model.predict([x_test])
#print(predictions)
print(np.argmax(predictions[0]))
img_path = x_test[0]
print(img_path.shape)
if(len(img_path.shape) == 3):
    plt.imshow(np.squeeze(img_path))
elif(len(img_path.shape) == 2):
    plt.imshow(img_path)
else:
    print("Image cannot be shown")

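You didn't define a validation generator, so validation_data receives the ImageDataGenerator object itself instead of data. Add this: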

validation_generator = validation_datagen.flow(x_test, y_test, batch_size=100)
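The reason this fixes the error: an ImageDataGenerator only stores augmentation and scaling settings, while .flow() returns an iterator that yields (images, labels) batches, which is the kind of input fit can consume. A minimal sketch to inspect one batch, assuming TensorFlow 2.x with tf.keras (ImageDataGenerator relies on SciPy for its image transforms) and using random arrays as a hypothetical stand-in for x_test and y_test:

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Hypothetical stand-in for CIFAR-10 test data: 250 random 32x32 RGB images
# with raw 0-255 pixel values and integer labels 0-9
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(250, 32, 32, 3)).astype("float32")
y = rng.integers(0, 10, size=(250, 1))

validation_datagen = ImageDataGenerator(rescale=1./255)  # settings only, no data
validation_generator = validation_datagen.flow(x, y, batch_size=100, shuffle=False)

# The iterator, not the ImageDataGenerator, produces actual batches
x_batch, y_batch = next(validation_generator)
print(x_batch.shape, y_batch.shape)  # (100, 32, 32, 3) (100, 1)
print(len(validation_generator))     # 3 batches: 100 + 100 + 50 images
```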

And in the training call:

model.fit_generator(train_generator,
                    steps_per_epoch = 2000 //batch_size, epochs =2,
                    validation_data = validation_generator, #change this
                    validation_steps = 800 //batch_size)
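Putting the fix together end to end, a minimal sketch assuming TensorFlow 2.x with tf.keras (SciPy installed for the augmentation transforms). The random arrays are a hypothetical stand-in for CIFAR-10, and the small CNN is hypothetical as well, since the question does not show its model. It calls model.fit, which accepts these iterators directly in TensorFlow 2.x, where fit_generator is deprecated.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

batch_size = 100

# Hypothetical stand-ins for CIFAR-10, kept small so the sketch runs quickly.
# Pixels stay in the raw 0-255 range; rescale=1./255 does the scaling once.
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(400, 32, 32, 3)).astype("float32")
y_train = rng.integers(0, 10, size=(400, 1))
x_test = rng.integers(0, 256, size=(200, 32, 32, 3)).astype("float32")
y_test = rng.integers(0, 10, size=(200, 1))

train_datagen = ImageDataGenerator(rescale=1./255, rotation_range=45, shear_range=0.2,
                                   zoom_range=0.2, horizontal_flip=True)
validation_datagen = ImageDataGenerator(rescale=1./255)

# Both generators are turned into iterators with .flow(); fit() consumes these
train_generator = train_datagen.flow(x_train, y_train, batch_size=batch_size)
validation_generator = validation_datagen.flow(x_test, y_test, batch_size=batch_size)

# Hypothetical small CNN standing in for the model the question omits
model = models.Sequential([
    layers.Input(shape=(32, 32, 3)),
    layers.Conv2D(16, (3, 3), activation="relu"),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

history = model.fit(train_generator,
                    steps_per_epoch=len(x_train) // batch_size, epochs=2,
                    validation_data=validation_generator,  # iterator, not the ImageDataGenerator
                    validation_steps=len(x_test) // batch_size,
                    verbose=0)

# Evaluate through the generator so test images get the same 1/255 scaling
score = model.evaluate(validation_generator, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
```

Using len(x_train) // batch_size for steps_per_epoch walks through the whole training set each epoch; the question's 2000 // batch_size draws only 20 batches (2,000 of CIFAR-10's 50,000 training images) per epoch, which may or may not be intended.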

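Follow Kaveh's answer. In addition, you are rescaling the pixels twice: once with the manual division and again with rescale=1./255 in the ImageDataGenerator instances. So remove these two lines: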

x_train /= 255.0              
x_test /= 255.0
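The effect of scaling twice is easy to check with plain NumPy: after the manual division by 255, the generator's rescale=1./255 multiplies by 1/255 again, so the brightest pixel ends up at about 0.0039 instead of 1.0. A minimal sketch using a few hypothetical pixel values:

```python
import numpy as np

# Hypothetical raw 8-bit pixel values (black, mid-gray, white)
pixels = np.array([0.0, 128.0, 255.0], dtype="float32")

once = pixels / 255.0     # the manual scaling from the question
twice = once * (1./255)   # what rescale=1./255 then applies on top

# once.max() is 1.0; twice.max() is 1/255, roughly 0.0039
print(once.max(), twice.max())
```

One caveat when removing those lines: model.evaluate(x_test, y_test) and model.predict([x_test]) in the question would then receive unscaled 0-255 images, so either evaluate through validation_generator or pass x_test / 255.0 so the test data matches what the model was trained on.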
