如何处理tensorflow连接形状错误



我正在处理一个简单的模型,如下所示,但有点难以处理concat错误。

def build_classifier_model():
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='input1')
preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
encoder_inputs = preprocessing_layer(text_input)
encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
outputs = encoder(encoder_inputs)
net = outputs['pooled_output']
net = tf.keras.layers.Dropout(0.1)(net)
side_input = tf.keras.layers.Input(shape=(2), dtype=tf.float32, name='input2')
print(net.shape)
print(side_input.shape)
net = tf.concat(values=[net, side_input], axis=1)
# net = tf.keras.layers.concatenate([net, side_input], axis=1)
net = tf.keras.layers.Dense(1, activation=None, name='classifier')(net)
return tf.keras.Model(inputs=[text_input, side_input], outputs=net)

我打印了"net"one_answers"side_input"的形状,并检查它们的形状是(无,512(和(无,2(。

然而,我有concat秩错误,这表明形状是(1,512(和(2(。

(None, 512)
(None, 2)
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-241-a75a6db54831> in <module>
1 classifier_model = build_classifier_model()
2 
----> 3 bert_raw_result = classifier_model([tf.constant(text_test), tf.reshape([0.3, 0.3], (2))])
4 
5 
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
7184 def raise_from_not_ok_status(e, name):
7185   e.message += (" name: " + name if name is not None else "")
-> 7186   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
7187 
7188 
InvalidArgumentError: Exception encountered when calling layer "tf.concat_7" (type TFOpLambda).
ConcatOp : Ranks of all input tensors should match: shape[0] = [1,512] vs. shape[1] = [2] [Op:ConcatV2] name: concat
Call arguments received:
• values=['tf.Tensor(shape=(1, 512), dtype=float32)', 'tf.Tensor(shape=(2,), dtype=float32)']
• axis=1
• name=concat

我使用的样本输入是

bert_raw_result = classifier_model([tf.constant(text_test), tf.reshape([0.3, 0.3], (2))])

问题出现在您的第二个输入中,即tf.reshape([0.3, 0.3], (2)),它是[2]大小的输入。

您的模型输入(即我所看到的side_input(需要[None, 2]大小的输入。

因此,解决方案是(假设您的第一个输入是[1, 512]大小(,

tf.reshape([0.3, 0.3], (1, 2))

相关内容

  • 没有找到相关文章

最新更新