How to connect the output of one network to the input of another network in Keras



I have two networks with the following architectures:

hidden_state = 256
embedding_size = 128
# Encoder
enc_input = Input(shape=(max_fr_len,), name='enc_input')
x = Embedding(fr_vocab, embedding_size)(enc_input)  # encoder reads the French side, so use the French vocabulary size
x = GRU(hidden_state, return_sequences=True)(x)
x = GRU(hidden_state, return_sequences=True)(x)
enc_output = GRU(hidden_state, return_sequences=False)(x)

# Decoder
dec_input_seq = Input(shape=(max_en_len,), name='dec_input_seq')  # shape must be a tuple, hence the trailing comma
dec_hidden_state = Input(shape=(hidden_state,), name='dec_hidden_state')
x = Embedding(en_vocab, embedding_size)(dec_input_seq)
x = GRU(hidden_state, return_sequences=True)(x, initial_state=dec_hidden_state)
x = GRU(hidden_state, return_sequences=True)(x, initial_state=dec_hidden_state)
x = GRU(hidden_state, return_sequences=True)(x, initial_state=dec_hidden_state)
dec_output = Dense(en_vocab, activation='softmax')(x)

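I want to create a third network, named model, that shares its parameters with the two networks above and in which dec_hidden_state is connected to enc_output, so that the encoder's hidden state is passed to the decoder's GRU cells.
To be clear, I want a model that takes enc_input and dec_input_seq as its inputs (since dec_hidden_state is already connected) and outputs dec_output.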

Thanks for your help.

Ah yes, as helpful as always, thanks. Anyway, for anyone wondering: I found a solution. It was not my idea, but I decided to post it.

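You just need to instantiate the layers first and write two separate functions that connect them later, instead of connecting them on the fly as they are created. The full code is below.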

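# Assumes hidden_state, embedding_size, fr_vocab and en_vocab are already defined.
from keras.layers import Input, Embedding, GRU, Dense
from keras.models import Model

# Encoder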
enc_input = Input(shape=(None,), name='enc_input')
enc_embedding = Embedding(fr_vocab, embedding_size)
enc_gru1 = GRU(hidden_state, return_sequences=True)
enc_gru2 = GRU(hidden_state, return_sequences=True)
enc_gru3 = GRU(hidden_state, return_sequences=False)

# Decoder
dec_input_seq = Input(shape=(None,), name='dec_input_seq')
dec_hidden_state = Input(shape=(hidden_state,), name='dec_hidden_state')
dec_embedding = Embedding(en_vocab, embedding_size)
dec_gru1 = GRU(hidden_state, return_sequences=True)
dec_gru2 = GRU(hidden_state, return_sequences=True)
dec_gru3 = GRU(hidden_state, return_sequences=True)
dec_dense = Dense(en_vocab, activation='softmax', name='decoder_output')

def connect_encoder():
    x = enc_input
    x = enc_embedding(x)
    x = enc_gru1(x)
    x = enc_gru2(x)
    x = enc_gru3(x)
    return x

def connect_decoder(hidden_state):
    x = dec_input_seq
    x = dec_embedding(x)
    x = dec_gru1(x, initial_state=hidden_state)
    x = dec_gru2(x, initial_state=hidden_state)
    x = dec_gru3(x, initial_state=hidden_state)
    x = dec_dense(x)
    return x

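# Standalone encoder and decoder; the decoder takes its initial state as an explicit input.
encoder_output = connect_encoder()
decoder_output = connect_decoder(dec_hidden_state)
encoder = Model(inputs=enc_input, outputs=encoder_output)
decoder = Model(inputs=[dec_input_seq, dec_hidden_state], outputs=decoder_output)

# End-to-end model: call the same decoder layers again, this time on the encoder output.
# Because the layer instances are reused, all three models share the same weights.
decoder_output = connect_decoder(encoder_output)
model = Model(inputs=[enc_input, dec_input_seq], outputs=decoder_output)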
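
If you want to convince yourself that the three models really share parameters, here is a minimal sketch of a check. It is a cut-down version of the code above, not the original setup: the toy vocabulary sizes, the single GRU per side, and the random token ids are all assumptions made only to keep it short. It compares the end-to-end model against running the encoder and the decoder one after the other.

```python
import numpy as np
from keras.layers import Input, Embedding, GRU, Dense
from keras.models import Model

# Toy sizes chosen only for this illustration (not from the original post).
fr_vocab, en_vocab = 50, 40
embedding_size, hidden_state = 8, 16

# Instantiate each layer once so every model built below reuses the same weights.
enc_input = Input(shape=(None,), name='enc_input')
enc_embedding = Embedding(fr_vocab, embedding_size)
enc_gru = GRU(hidden_state)

dec_input_seq = Input(shape=(None,), name='dec_input_seq')
dec_hidden_state = Input(shape=(hidden_state,), name='dec_hidden_state')
dec_embedding = Embedding(en_vocab, embedding_size)
dec_gru = GRU(hidden_state, return_sequences=True)
dec_dense = Dense(en_vocab, activation='softmax')

def connect_decoder(state):
    # Wire the shared decoder layers onto whichever state tensor is passed in.
    x = dec_embedding(dec_input_seq)
    x = dec_gru(x, initial_state=state)
    return dec_dense(x)

encoder_output = enc_gru(enc_embedding(enc_input))
encoder = Model(enc_input, encoder_output)
decoder = Model([dec_input_seq, dec_hidden_state], connect_decoder(dec_hidden_state))
model = Model([enc_input, dec_input_seq], connect_decoder(encoder_output))

# Random integer token ids standing in for real sentences.
rng = np.random.default_rng(0)
src = rng.integers(0, fr_vocab, size=(2, 5))
tgt = rng.integers(0, en_vocab, size=(2, 7))

# Path 1: the combined model. Path 2: encoder first, then decoder.
end_to_end = model.predict([src, tgt], verbose=0)
state = encoder.predict(src, verbose=0)
two_step = decoder.predict([tgt, state], verbose=0)

shares_layers = dec_gru in model.layers and dec_gru in decoder.layers
outputs_match = np.allclose(end_to_end, two_step, atol=1e-5)
print(shares_layers, outputs_match)
```

If both checks hold, training model also updates encoder and decoder, so the two smaller models can be used on their own afterwards, for example to decode step by step at inference time.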
