我正在为一份报告构建一个 LSTM,并想总结一下有关它的事情。但是,我已经看到了在 Keras 中构建 LSTM 的两种不同方法,它们为参数数量生成两个不同的值。
我想了解为什么参数以这种方式不同。
这个问题正确地说明了为什么这段代码
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.layers import Embedding
from keras.layers import LSTM
model = Sequential()
model.add(LSTM(256, input_dim=4096, input_length=16))
model.summary()
结果为4457472参数。
据我所知,以下两个 LSTM 应该是相同的
m2 = Sequential()
m2.add(LSTM(1, input_dim=5, input_length=1))
m2.summary()
m3 = Sequential()
m3.add(LSTM((1),batch_input_shape=(None,5,1)))
m3.summary()
但是,m2
会产生28
参数,但m3
会产生12
参数。为什么?
对于具有 1暗光输入的 1 单元 LSTM,如何计算 5?
包含警告消息。希望对您有所帮助。
输出
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (None, 256) 4457472
=================================================================
Total params: 4,457,472
Trainable params: 4,457,472
Non-trainable params: 0
_________________________________________________________________
Warning (from warnings module):
File "difparam.py", line 11
m2.add(LSTM(1, input_dim=5, input_length=1))
UserWarning: The `input_dim` and `input_length` arguments in recurrent layers are deprecated. Use `input_shape` instead.
Warning (from warnings module):
File "difparam.py", line 11
m2.add(LSTM(1, input_dim=5, input_length=1))
UserWarning: Update your `LSTM` call to the Keras 2 API: `LSTM(1, input_shape=(1, 5))`
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_2 (LSTM) (None, 1) 28
=================================================================
Total params: 28
Trainable params: 28
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_3 (LSTM) (None, 1) 12
=================================================================
Total params: 12
Trainable params: 12
Non-trainable params: 0
_________________________________________________________________
m2是基于 Stack Overflow 问题的信息构建的,而m3是基于 YouTube 上的这段视频构建的。
因为正确的值是input_dim = 1
和input_length = 5
。
它甚至写在您收到的警告中,其中m2
的输入形状与m3
中使用的形状不同:
用户警告:更新对 Keras 2 API 的
LSTM
调用:LSTM(1, input_shape=(1, 5))
强烈建议您使用警告中的建议。