I'm working on a simple model, shown below, but I'm having trouble with a concat error.
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # registers the custom ops used by TF Hub BERT preprocessing models

def build_classifier_model():
    # Text branch: raw strings -> preprocessing -> BERT encoder -> pooled output
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='input1')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    net = tf.keras.layers.Dropout(0.1)(net)
    # Side branch: two extra float features per example
    side_input = tf.keras.layers.Input(shape=(2,), dtype=tf.float32, name='input2')
    print(net.shape)
    print(side_input.shape)
    net = tf.concat(values=[net, side_input], axis=1)
    # net = tf.keras.layers.concatenate([net, side_input], axis=1)
    net = tf.keras.layers.Dense(1, activation=None, name='classifier')(net)
    return tf.keras.Model(inputs=[text_input, side_input], outputs=net)
I printed the shapes of "net" and "side_input" and confirmed that they are (None, 512) and (None, 2).
However, I get a concat rank error, which reports the shapes as (1, 512) and (2).
(None, 512)
(None, 2)
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-241-a75a6db54831> in <module>
1 classifier_model = build_classifier_model()
2
----> 3 bert_raw_result = classifier_model([tf.constant(text_test), tf.reshape([0.3, 0.3], (2))])
4
5
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
7184 def raise_from_not_ok_status(e, name):
7185 e.message += (" name: " + name if name is not None else "")
-> 7186 raise core._status_to_exception(e) from None # pylint: disable=protected-access
7187
7188
InvalidArgumentError: Exception encountered when calling layer "tf.concat_7" (type TFOpLambda).
ConcatOp : Ranks of all input tensors should match: shape[0] = [1,512] vs. shape[1] = [2] [Op:ConcatV2] name: concat
Call arguments received:
• values=['tf.Tensor(shape=(1, 512), dtype=float32)', 'tf.Tensor(shape=(2,), dtype=float32)']
• axis=1
• name=concat
The sample input I'm using is
bert_raw_result = classifier_model([tf.constant(text_test), tf.reshape([0.3, 0.3], (2))])
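The problem is in your second input, tf.reshape([0.3, 0.3], (2)), which has shape [2].
Your model's input (side_input, as far as I can see) expects an input of shape [None, 2], that is, a rank-2 tensor whose first dimension is the batch.
So the solution (assuming your first input is of size [1, 512]) is:
tf.reshape([0.3, 0.3], (1, 2))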
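To see the rank mismatch in isolation, here is a minimal sketch that reproduces it without the BERT layers. It assumes eager TensorFlow 2.x; the tensor named pooled is only a hypothetical stand-in of shape (1, 512) for the dropout output, not the real encoder result.

```python
import tensorflow as tf

# Hypothetical stand-in for the BERT branch output: batch of 1, 512 features.
pooled = tf.zeros((1, 512), dtype=tf.float32)

# Rank-1 side input, as in the failing call: shape (2,).
flat_side = tf.reshape([0.3, 0.3], (2,))

try:
    tf.concat(values=[pooled, flat_side], axis=1)
    rank_error_raised = False
except (tf.errors.InvalidArgumentError, ValueError):
    # Concatenating a rank-2 tensor with a rank-1 tensor is rejected.
    rank_error_raised = True

# Rank-2 side input with an explicit batch dimension: shape (1, 2).
batched_side = tf.reshape([0.3, 0.3], (1, 2))
merged = tf.concat(values=[pooled, batched_side], axis=1)

# Equivalent fix: add the batch axis to the flat tensor instead of reshaping.
expanded_side = tf.expand_dims(flat_side, axis=0)

print(rank_error_raised)
print(merged.shape)
print(expanded_side.shape)
```

The shapes printed inside build_classifier_model describe the symbolic Keras inputs, where None is the batch dimension. The tensors you pass when calling the model must supply that batch dimension themselves, which is why (1, 2) works and (2) does not.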