如何提高BERT keras中枢层输入的秩(ndim)以学习秩



我正在尝试使用tensorflow hub上提供的预训练BERT来实现一个学习到排序模型。我使用的是ListNet loss函数的变体,它要求每个训练实例都是与查询相关的几个排序文档的列表。我需要模型能够接受形状(batch_size、listrongize、sentence_length(的数据,其中模型在每个训练实例中的"listrongize"轴上循环,返回秩并将其传递给损失函数。在一个只由密集层组成的简单模型中,这可以通过增加输入层的维度来轻松完成。例如:

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model
input = Input([6,10])
x = Dense(20,activation='relu')(input)
output = Dense(1, activation='sigmoid')(x)
model = Model(inputs=input, outputs=output)

现在,在计算损失和更新梯度之前,该模型将在长度为10的向量上执行6次前向传递。

我正在尝试对BERT模型及其预处理层做同样的事情:

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
bert_preprocess_model = hub.KerasLayer('https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1')
bert_model = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')

text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
processed_input = bert_preprocess_model(text_input)
output = bert_model(processed_input)
model = tf.keras.Model(text_input, output)

但是,当我试图将"text_input"的形状更改为(6(,或者以任何方式干预它时,它总是会导致相同类型的错误:

ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (3 total):
* Tensor("inputs:0", shape=(None, 6), dtype=string)
* False
* None
Keyword arguments: {}

Expected these arguments to match one of the following 4 option(s):

Option 1:
Positional arguments (3 total):
* TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
* False
* None
Keyword arguments: {}
....

根据https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer,您似乎可以配置集线器的输入形状。KerasLayer通过tf.keras.layers.InputSpec。在我的情况下,我想它应该是这样的:

bert_preprocess_model.input_spec = tf.keras.layers.InputSpec(ndim=2)
bert_model.input_spec = tf.keras.layers.InputSpec(ndim=2)

当我运行上面的代码时,属性确实会发生更改,但当试图构建模型时,会出现完全相同的错误。

有没有任何方法可以在不需要创建自定义训练循环的情况下轻松解决此问题?

假设你有一批B个例子,每个例子都有N个文本字符串,这就形成了一个形状为[B,N]的二维张量。使用tf.reshape((,您可以将其转换为形状为[B*N]的一维张量,通过BERT(保留输入的顺序(发送,然后将其重塑为[B,N]。(还有tf.keras.layers.Reshape,但它对您隐藏了批次维度。(

如果不是每次都是N个文本字符串,则必须在侧面进行一些记账(例如,将输入存储在tf.RaggedTensor中,在其.values上运行BERT,并根据结果构造具有相同.row_splits的新RaggedTensor。(

最新更新