使用 Keras 训练 CNN 后预测特定图像(particular image)的类别



Training

import keras  
import numpy as np  
import matplotlib.pyplot as plt


from keras.preprocessing.image import ImageDataGenerator  


# Training-image generator configured with rotation, width/height shift and
# zoom augmentation; rescale multiplies every pixel value by 1/255.
datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    rescale=1.0 / 255.0,
)

# Bare expression kept from the original; it only displays something in an
# interactive session and does nothing when run as a script.
type(datagen)


from keras.models import Sequential  
from keras.layers import Conv2D,MaxPool2D,Flatten,Dense,Activation  
from keras.activations import relu , softmax  
from keras.losses import categorical_crossentropy  
from keras.optimizers import SGD,RMSprop  
from keras.callbacks import TensorBoard  


# CNN classifier: three conv/max-pool stages over a 150x150x3 input, followed
# by a stack of fully connected ReLU layers and a 5-way softmax output.
model = Sequential()

# The first convolution declares the input shape; later stages infer theirs.
model.add(Conv2D(32, (3, 3), input_shape=(150, 150, 3), activation="relu"))
model.add(MaxPool2D(pool_size=(2, 2)))
for n_filters in (32, 64):
    model.add(Conv2D(n_filters, (3, 3), activation="relu"))
    model.add(MaxPool2D(pool_size=(2, 2)))

# Flatten the feature maps and run them through the dense head.
model.add(Flatten())
for n_units in (1024, 512, 512, 512, 512, 512):
    model.add(Dense(n_units, activation="relu"))

# One output unit per class.
model.add(Dense(5, activation="softmax"))


# Configure training: categorical cross-entropy loss, an SGD optimizer with
# its default settings, and accuracy reported under the "acc" metric name.
model.compile(
    optimizer=SGD(),
    loss="categorical_crossentropy",
    metrics=["acc"],
)


# Iterate over the dataset directory in batches of 100 images, each resized
# to 150x150, using the augmentation settings held by datagen.
train_gen = datagen.flow_from_directory(
    "/home/vishu//Desktop/basics/dataset",
    target_size=(150, 150),
    batch_size=100,
)


# TensorBoard callback that writes its logs into the current directory.
tb = TensorBoard(log_dir=".")

# Train for 2 epochs on the generator. The callback must be passed in here;
# a callback that is only constructed is never invoked and logs nothing.
model_history = model.fit_generator(train_gen, epochs=2, callbacks=[tb])

预测

import cv2  

# Label strings '1' through '5', one per class.
# NOTE(review): presumably meant to be indexed by the predicted class id;
# verify the order against the training generator's class indices.
CATEGORIES = [str(label) for label in range(1, 6)]
def prepare(filepath):
    """Load an image file and preprocess it the same way as the training data.

    The image is read with OpenCV, converted from BGR to RGB channel order,
    resized to 150x150, scaled into the [0, 1] range and given a leading
    batch axis so it can be passed directly to the model.

    Args:
        filepath: Path of the image file to load.

    Returns:
        A float32 array of shape (1, 150, 150, 3) with values in [0, 1].

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``filepath``.
    """
    IMG_SIZE = 150
    img_array = cv2.imread(filepath, cv2.IMREAD_COLOR)
    if img_array is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("could not read image: {}".format(filepath))
    # OpenCV decodes to BGR, while the Keras directory loader yields RGB by
    # default; swap the channels so inference sees the same order as training.
    img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
    new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
    # Apply the same 1/255 rescale the training generator uses; feeding raw
    # 0-255 values to a network trained on 0-1 inputs skews every prediction.
    scaled = new_array.astype("float32") / 255.0
    return scaled.reshape(-1, IMG_SIZE, IMG_SIZE, 3)
# Preprocess one image from disk, ask the model for its predicted class
# index, and print the result.
sample = prepare('/home/vishu/Desktop/basics/dataset/d2.jpeg')
prediction = model.predict_classes([sample])
print(prediction)

使用它后,它总是给我输出 4
我应该如何预测正确的图像类别?
在这里,我从文件夹中读取输入图像。我已经为 5 个类别创建了 5 个文件夹,那么我应该如何预测图像的类别?

您忘记了在 ImageDataGenerator 中所做的重新缩放(除以 255)。新的测试数据也需要进行同样的处理,因此您必须在 prepare 函数中执行此操作。

最新更新