我正在编写一个交通标志识别程序,使用的图像数据大小为 32x32 像素。我创建了一个深度 CNN 模型,并在另一个文件中加载该模型来进行分类。
def classify(file_path):
    """Classify the traffic-sign image at *file_path* and display the result.

    The image is read with ``imread``, scaled to float32, converted to a
    single grayscale channel, resized to 32x32 and passed to ``model`` as a
    batch of one. The label with the highest score is printed and shown in
    the ``label`` widget.

    The model's input layer expects ``(32, 32, 1)`` images (see the
    "expected conv2d_3_input to have shape (32, 32, 1)" error), so RGB input
    must be reduced to one channel before prediction.
    """
    image = imread(file_path)
    image = tf.image.convert_image_dtype(image, tf.float32)

    if image.shape.ndims == 2:
        # Already single-channel (H, W): add the channel axis the model needs.
        image = tf.expand_dims(image, axis=-1)
    elif image.shape[-1] == 4:
        # RGBA (e.g. PNG with transparency): drop alpha;
        # rgb_to_grayscale requires exactly 3 channels.
        image = image[..., :3]
    if image.shape[-1] == 3:
        # (H, W, 3) -> (H, W, 1) to match the model's grayscale input.
        image = tf.image.rgb_to_grayscale(image)

    image = tf.image.resize(image, size=[32, 32])
    # Add the batch axis: (32, 32, 1) -> (1, 32, 32, 1).
    image = np.expand_dims(image, axis=0)
    pred = model.predict(image)
    sign = labels[np.argmax(pred)]
    print(sign)
    label.configure(foreground='#011638', text=sign)
def show_classify_button(file_path):
    """Place a "Classify Image" button that runs ``classify(file_path)``.

    A new ``Button`` is created as a child of ``top`` on every call and
    placed at relative position (0.79, 0.46). Buttons from earlier calls
    are not destroyed here.
    """
    def _on_click():
        return classify(file_path)

    button = Button(
        top,
        text="Classify Image",
        command=_on_click,
        padx=10,
        pady=5,
    )
    button.configure(
        background='#364156',
        foreground='white',
        font=('arial', 10, 'bold'),
    )
    button.place(rely=0.46, relx=0.79)
def upload_image():
    """Ask the user for an image file, preview it, and offer classification.

    Cancelling the file dialog does nothing. A file that cannot be opened or
    decoded as an image is reported on stdout. The original code hid every
    exception with a bare ``except: pass``, which also masked real bugs.
    """
    file_path = filedialog.askopenfilename()
    # askopenfilename() returns an empty value when the dialog is cancelled.
    if not file_path:
        return

    try:
        uploaded = Image.open(file_path)
        # Pillow decodes lazily, so a corrupt file may only fail here.
        uploaded.thumbnail((top.winfo_width() / 2.25,
                            top.winfo_height() / 2.25))
    except (OSError, ValueError) as err:
        # Pillow signals missing/unreadable/unrecognised images with OSError
        # subclasses (FileNotFoundError, UnidentifiedImageError).
        print("Could not load image {!r}: {}".format(file_path, err))
        return

    im = ImageTk.PhotoImage(uploaded)
    sign_image.configure(image=im)
    # Keep a reference on the widget so the PhotoImage is not
    # garbage-collected while it is displayed.
    sign_image.image = im
    label.configure(text='')
    show_classify_button(file_path)
当我点击“Classify Image”(分类)按钮时,出现了以下错误:
错误信息:
Exception in Tkinter callback
Traceback (most recent call last):
File "C:\Users\Monster\Anaconda3\lib\tkinter\__init__.py", line 1705, in __call__
return self.func(*args)
File "<ipython-input-113-c3d169264d01>", line 2, in <lambda>
classify_b=Button(top,text="Classify Image",command=lambda: classify(file_path),padx=10,pady=5)
File "<ipython-input-112-b328d3eae35f>", line 7, in classify
pred = model.predict_classes([image])[0]
File "C:\Users\Monster\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\sequential.py", line 318, in predict_classes
proba = self.predict(x, batch_size=batch_size, verbose=verbose)
File "C:\Users\Monster\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1060, in predict
x, check_steps=True, steps_name='steps', steps=steps)
File "C:\Users\Monster\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 2651, in _standardize_user_data
exception_prefix='input')
File "C:\Users\Monster\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 385, in standardize_input_data
str(data_shape))
ValueError: Error when checking input: expected conv2d_3_input to have shape (32, 32, 1) but got array with shape (32, 32, 3)
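您把 RGB 图像(3 个通道)输入给了一个期望灰度图像(1 个通道)的模型,所以图像形状 (32, 32, 3) 与模型要求的 (32, 32, 1) 不匹配。建议使用 `tf.image.rgb_to_grayscale` 把图像转换成单通道灰度图来解决这个问题(如果图片是带透明通道的 PNG,即 4 个通道,需要先用 `image[..., :3]` 去掉 alpha 通道):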
# ...
# Add the batch axis: (32, 32, 3) -> (1, 32, 32, 3).
image = np.expand_dims(image, axis=0)
# The model expects single-channel input of shape (32, 32, 1) per image
# (see the ValueError). rgb_to_grayscale collapses the trailing 3-channel
# axis to 1, giving (1, 32, 32, 1). It assumes exactly 3 channels, so
# RGBA input would need its alpha channel dropped first.
image = tf.image.rgb_to_grayscale(image)
pred = model.predict(image)
# ...