我有一个包含一些LSTM层的Keras模型。我知道我可以通过get_weights()
方法得到LSTM层的权值,结果是一个由核、循环核和偏置三个元素组成的列表。
作为文档状态,每个元素包括LSTM层中4个门的权重。但是,它没有说明它们的存储顺序。例如,如果LSTM层有N个单元,则偏置向量将由4*N个元素组成。这些元素中哪一个对应于第一/第二/第三/第四门?
顺序为i, f, c, o
,分别代表输入门、遗忘门、单元门和输出门。您可以在这里获得LSTMCell
实现的信息。
lstm = LSTM(100)
lstm(np.zeros((64,10,5)))
kernel = lstm.weights[0]
w_i,w_f,w_c,w_o = tf.split(kernel,4,axis=1)
print(*(w.shape for w in (w_i,w_f,w_c,w_o)))#all are (5, 100)