Problem fitting an LSTM model



I have a very simple LSTM model, defined as:

def get_lstm_model(shape_input, num_output):
    model = Sequential([
        layers.Input((shape_input, num_output)),
        layers.LSTM(64),
        layers.Dense(32, activation = 'relu'),
    ])
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics = ['accuracy'])
    return model

The model is instantiated as follows:

model_mlp = get_lstm_model(8,5)
model_mlp.summary()

Now, when I fit the model, I get an error on this line:

model_history = model_mlp.fit(x_train, y_train, validation_split=0.2,
epochs=500, batch_size=5000)

The error I get is:

"Input 0 of layer "lstm_3" is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: (None, 8)"
For clarity, the shape of x_train is (2134, 8), while the shape of y_train is (2134, 5).

Any help would be greatly appreciated.

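There is indeed a problem with the input dimensions. I'm not sure what your data represents, so I created two dummy numpy arrays, x_train and y_train, with the required shapes.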

As I understand it, you want an output shape of 5, so you have to specify that in the last layer.
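Because the loss is categorical_crossentropy, y_train is expected to have one row per sample and 5 columns, typically one-hot class labels. The question does not say how the labels were built, so the integer class ids below are purely hypothetical. This is only a sketch of how the label width lines up with the number of units in the last layer.

```python
import numpy as np

num_output = 5  # number of classes, matching the units of the last Dense layer

# Hypothetical integer class ids, one per sample (not taken from the question).
labels = np.array([0, 3, 1, 4, 2, 2])

# One-hot encode: row i has a 1 in column labels[i] and 0 everywhere else.
y_one_hot = np.eye(num_output)[labels]

print(y_one_hot.shape)  # (6, 5)
```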

Regarding the input shape, the main issue is that the LSTM layer expects three-dimensional input. From the LSTM documentation:

inputs: A 3D tensor with shape [batch, timesteps, feature].

So x_train must be 3D: the number of samples followed by two more dimensions (time steps and features), as in the code below.
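If the data really is 2D, like the (2134, 8) array in the question, it has to be reshaped before it reaches the LSTM. Here is a minimal sketch of two common options. Which one is appropriate depends on what the 8 columns mean, which the question does not say, so both interpretations are assumptions, and the array is dummy data.

```python
import numpy as np

# Dummy stand-in for the question's x_train, shape (samples, 8).
x_train = np.zeros((2134, 8))

# Assumption A: the 8 columns are 8 time steps of a single feature.
x_as_sequence = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)

# Assumption B: the 8 columns are 8 features observed at a single time step.
x_as_single_step = x_train[:, np.newaxis, :]

print(x_as_sequence.shape)     # (2134, 8, 1)
print(x_as_single_step.shape)  # (2134, 1, 8)
```

With option A the per-sample input shape is (8, 1); with option B it is (1, 8), in which case the LSTM sees no real sequence and a plain Dense network might serve just as well.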

Code:

import numpy as np
from tensorflow.keras import layers, Sequential

def get_lstm_model(shape_input, num_output):
    model = Sequential()
    # Per-sample input shape is (timesteps, features);
    # None leaves the number of time steps variable.
    model.add(layers.Input(shape=(None, shape_input)))
    model.add(layers.LSTM(64))
    model.add(layers.Dense(32, activation='relu'))
    # One output unit per class; softmax pairs with categorical_crossentropy.
    model.add(layers.Dense(num_output, activation='softmax'))
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

model_mlp = get_lstm_model(8, 5)
model_mlp.summary()

# Dummy data: 2134 samples, 8 time steps, 8 features per step.
x_train = np.zeros((2134, 8, 8))
y_train = np.zeros((2134, 5))
model_history = model_mlp.fit(x_train, y_train, validation_split=0.2,
                              epochs=2, batch_size=5000)

Summary:

Model: "sequential_40"
_________________________________________________________________
Layer (type)                Output Shape              Param #   
=================================================================
lstm_33 (LSTM)              (None, 64)                18688     

dense_43 (Dense)            (None, 32)                2080      

dense_44 (Dense)            (None, 5)                 165       

=================================================================
Total params: 20,933
Trainable params: 20,933
Non-trainable params: 0
_________________________________________________________________
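The parameter counts in this summary can be checked by hand. The sketch below assumes the usual LSTM layout of four weight sets (three gates plus the cell candidate), each with input weights, recurrent weights and one bias vector. The helper names are made up for illustration.

```python
def lstm_params(units, features):
    # Four weight sets (three gates plus the cell candidate), each with
    # input weights, recurrent weights and a bias vector.
    return 4 * (units * (features + units) + units)

def dense_params(units, inputs):
    # One weight per input-output pair plus one bias per unit.
    return inputs * units + units

lstm = lstm_params(units=64, features=8)
dense_1 = dense_params(units=32, inputs=64)
dense_2 = dense_params(units=5, inputs=32)
total = lstm + dense_1 + dense_2

print(lstm, dense_1, dense_2, total)  # 18688 2080 165 20933
```

The LSTM count depends on the number of features (8) but not on the number of time steps, which is why the time dimension can be left as None.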
