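I get an error while training the model:
Error when checking input: expected lstm_22_input to have 3 dimensions, but got array with shape (15, 33297)
I have tried changing input_shape 100 times, but I always end up with this error.
I have also changed the input dimensions many times with expand_dims, and I tried transforming the array as well, but the same error appears.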
from keras.models import Sequential
from keras.layers import LSTM, Dense, Embedding
model=Sequential()
model.add(LSTM(50,return_sequences=True, input_shape=(X_train.shape[0],
X_train.shape[1],)))
model.add(LSTM(32, return_sequences=True ))
model.add(Dense(2, activation='softmax'))
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy'])
print("Train...")
model.fit(X_test,y_test,batch_size=5, epochs=10)
Error:
ValueError Traceback (most recent call last)
<ipython-input-89-afa0c9eaa4e6> in <module>()
1 print("Train...")
----> 2 model.fit(X_test,y_test,batch_size=5, epochs=10)
~\Anaconda3\lib\site-packages\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
950 sample_weight=sample_weight,
951 class_weight=class_weight,
--> 952 batch_size=batch_size)
953 # Prepare validation data.
954 do_validation = False
~\Anaconda3\lib\site-packages\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
749 feed_input_shapes,
750 check_batch_axis=False, # Don't enforce the batch size.
--> 751 exception_prefix='input')
752
753 if y is not None:
~\Anaconda3\lib\site-packages\keras\engine\training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
126 ': expected ' + names[i] + ' to have ' +
127 str(len(shape)) + ' dimensions, but got array '
--> 128 'with shape ' + str(data_shape))
129 if not check_batch_axis:
130 data_shape = data_shape[1:]
ValueError: Error when checking input: expected lstm_22_input to have 3 dimensions, but got array with shape (15, 33297)
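The problem is in this line:
model.add(LSTM(50,return_sequences=True, input_shape=(X_train.shape[0],
X_train.shape[1],)))
input_shape describes one sample and leaves out the batch axis, so a two-element tuple makes the LSTM expect 3D input of shape (batch, timesteps, features), while your array is 2D with shape (15, 33297). The trailing comma after
X_train.shape[1],
is not the cause: (a, b,) and (a, b) are the same tuple. Passing X_train.shape[0] also puts the number of samples where the number of timesteps belongs. Reshape the data to 3D first, then change the line to
model.add(LSTM(50,return_sequences=True, input_shape=(X_train.shape[1], X_train.shape[2])))
to make sure the input dimensions match.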
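A minimal sketch of the reshaping step, using only NumPy. The zero-filled array X is a stand-in for the data in the question, and the two layouts shown are assumptions: whether each row is one timestep with 33297 features or 33297 timesteps with one feature depends on what the data represents, which the question does not say.

```python
import numpy as np

# Stand-in for the 2D array from the question: 15 samples, 33297 values each.
X = np.zeros((15, 33297), dtype=np.float32)

# A trailing comma does not add a dimension: both expressions are the same 2-tuple.
assert (X.shape[0], X.shape[1],) == (X.shape[0], X.shape[1])

# Option A (assumption): each sample is a single timestep with 33297 features.
X_one_step = np.expand_dims(X, axis=1)       # shape (15, 1, 33297)

# Option B (assumption): each sample is 33297 timesteps with 1 feature each.
X_univariate = np.expand_dims(X, axis=-1)    # shape (15, 33297, 1)

# input_shape for the first LSTM layer is the per-sample shape, without the batch axis.
input_shape_a = X_one_step.shape[1:]         # (1, 33297)
input_shape_b = X_univariate.shape[1:]       # (33297, 1)

print(X_one_step.shape, input_shape_a)
print(X_univariate.shape, input_shape_b)
```

Whichever layout fits, the array passed to model.fit needs the same 3D reshaping as the one used to build input_shape; note that the question derives input_shape from X_train but fits on X_test. Also, with return_sequences=True on the last LSTM, the Dense layer produces one output per timestep, so for a single label per sample that layer would likely need return_sequences=False.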