我想训练一个 CNN,用来对 CIFAR-10 数据集中的图像进行分类。参照之前的一个 Keras 练习,代码应该像下面这样工作:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from random import sample
import tensorflow as tf
from tensorflow import keras
# Lookup table: CIFAR-10 integer label (0-9) -> human-readable class name.
# Built from the ordered list of names so the index of each name is its label.
cat_dict = dict(enumerate((
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)))
def assign_class(val):
    """Return the class name from ``cat_dict`` for a single label.

    Args:
        val: The label to look up. Accepted forms:
            - a plain int or NumPy integer scalar;
            - a size-1 array, such as one row of an ``(N, 1)`` label array;
            - a one-hot or probability vector, in which case its arg-max
              index is used.

    Returns:
        The string stored in ``cat_dict`` for the resolved index.

    Raises:
        KeyError: If the resolved index is not a key of ``cat_dict``.
        ValueError: If ``val`` is an empty array.
    """
    arr = np.asarray(val)
    if arr.size > 1:
        # One-hot / probability vector (e.g. after ``to_categorical``).
        index = int(arr.argmax())
    else:
        # ``int()`` on an array with ndim > 0 is deprecated since NumPy 1.25,
        # so extract the scalar explicitly with ``.item()``.
        index = int(arr.item())
    return cat_dict[index]
def show_imgs(X, Y, rows=5, cols=5):
    """Show the first images of ``X`` in a grid, each titled with its class name.

    Args:
        X: Indexable collection of images that ``imshow`` can display.
        Y: Labels aligned with ``X``. Each ``Y[k]`` is passed to ``assign_class``.
        rows: Number of grid rows. The default of 5 matches the original layout.
        cols: Number of grid columns. The default of 5 matches the original layout.

    At most ``rows * cols`` images are drawn. If ``X`` holds fewer, only
    ``len(X)`` are drawn and the remaining cells stay blank instead of raising
    ``IndexError``. A new figure is created on every call, rather than reusing
    the fixed figure number 1.
    """
    # 4 inches per cell reproduces the original 20x20 figure for a 5x5 grid.
    # squeeze=False keeps ``axes`` 2-D even when rows or cols is 1.
    _, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
    n_shown = min(rows * cols, len(X))
    for k, ax in enumerate(axes.flat):
        # Hide ticks and frames on every cell, including unused ones.
        ax.axis('off')
        if k >= n_shown:
            continue
        # cmap only affects single-channel images; RGB images ignore it.
        ax.imshow(X[k], cmap='gray')
        ax.set_title(assign_class(Y[k]))
    # show the plot
    plt.show()
# Load data & split data between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
show_imgs(x_train, y_train)
print('Shape of training data:', x_train.shape)
print('Shape of test data:', x_test.shape)

E = 5  # epochs
B = 128  # batch size
n_classes = 10  # CIFAR-10 has ten classes (see cat_dict)

# Normalization: scale pixel values from [0, 255] to [0, 1].
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255.0
x_test /= 255.0

# One-hot encode the integer labels to match the categorical_crossentropy loss.
# Use n_classes instead of repeating the literal 10.
y_train = keras.utils.to_categorical(y_train, n_classes)
y_test = keras.utils.to_categorical(y_test, n_classes)
print(x_train.shape[0], 'training samples')
print(x_test.shape[0], 'test samples')
print(y_train.shape[0], 'training label samples')
print(y_test.shape[0], 'test label samples')

cnn = keras.models.Sequential()
cnn.add(keras.layers.Conv2D(filters=32, kernel_size=(2, 2), input_shape=(32, 32, 3),
                            padding='valid', strides=(1, 1)))
cnn.add(keras.layers.Activation('relu'))
cnn.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
cnn.add(keras.layers.Flatten())
# The Dense input size is inferred from Flatten. The shapes go
# 32x32x3 -> conv 31x31x32 -> pool 15x15x32 -> 7200 features.
# The former input_shape=(3072,) did not match this and was dropped, since
# input_shape is only meaningful on a model's first layer.
cnn.add(keras.layers.Dense(n_classes, activation='softmax'))
cnn.summary()

cnn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# NOTE(review): the test set doubles as validation data here. Do not use the
# reported val_* metrics for model selection, or they stop being a held-out
# estimate.
# NOTE(review): if training stalls right after printing "Epoch 1/5", suspect
# the TensorFlow installation/environment rather than this code. The original
# report says a fresh environment resolved it.
log = cnn.fit(x_train, y_train, batch_size=B, epochs=E,
              validation_data=(x_test, y_test), verbose=1)
我希望 `.fit()` 开始拟合模型,但我得到的唯一输出是 "Epoch 1/5",之后就什么都没有了(只是我的电脑开始发出像一架波音 747 那样的声音)。
更新:重新安装了一个全新的环境后,问题就解决了。