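I'm trying to build a text classification model on top of a pretrained BERT model, but I keep getting an error when I try to fit the model.
The error says: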
ValueError: Layer "model_1" expects 2 inputs but it received only 1 input tensor.
Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, 309) dtype=int32>]
I'm also using TensorFlow and other Python libraries.
Here is my code:
import numpy as np
from data_helpers import load_data
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.layers import Embedding
from sklearn.model_selection import train_test_split
from keras.layers.convolutional import Conv1D
from keras.layers.convolutional import MaxPooling1D
from keras.layers import Dropout,Flatten
from sklearn.metrics import classification_report
from transformers import TFBertModel
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
# Data Preparation
print("Load data...")
x, y, vocabulary, vocabulary_inv = load_data()
np.save('data1-vocab.npy', vocabulary)
sequence_length = x.shape[1]
X_train, X_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=42)
bert_model = TFBertModel.from_pretrained('bert-base-uncased')
def create_model(bert_model, max_len=sequence_length):
    ##params###
    opt = tf.keras.optimizers.Adam(learning_rate=1e-5, decay=1e-7)
    loss = tf.keras.losses.CategoricalCrossentropy()
    accuracy = tf.keras.metrics.CategoricalAccuracy()
    # Two separate inputs: token IDs and the attention mask
    input_ids = tf.keras.Input(shape=(max_len,), dtype='int32')
    attention_masks = tf.keras.Input(shape=(max_len,), dtype='int32')
    # Index [1] is BERT's pooled output for the [CLS] token
    embeddings = bert_model([input_ids, attention_masks])[1]
    output = tf.keras.layers.Dense(3, activation="softmax")(embeddings)
    model = tf.keras.models.Model(inputs=[input_ids, attention_masks], outputs=output)
    model.compile(opt, loss=loss, metrics=[accuracy])
    return model
model = create_model(bert_model,sequence_length)
model.summary()
model.fit(X_train, y_train, epochs=32, batch_size=32,verbose=1)
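I've already changed the arguments to the .fit() function, but nothing works.
The problem is clear from your code. You define the model as: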
model = tf.keras.models.Model(inputs = [input_ids,attention_masks], outputs = output)
So the model expects 2 inputs (input_ids and attention_masks), but in the fit function you pass only 1 input to the model:
model.fit(X_train, y_train, epochs=32, batch_size=32,verbose=1)
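A minimal sketch of building the missing second input, assuming X_train is a padded NumPy array of token IDs and that padding uses ID 0. The padding ID is an assumption: since load_data() builds its own vocabulary, look up the padding token's actual index in vocabulary. The helper make_attention_mask is hypothetical, not part of any library; it marks real tokens with 1 and padding with 0, which is the convention BERT's attention_mask follows.

```python
import numpy as np

# Assumption: ID 0 is the padding token. Check `vocabulary` from load_data().
PAD_ID = 0

def make_attention_mask(token_ids: np.ndarray, pad_id: int = PAD_ID) -> np.ndarray:
    """Return an int32 mask with 1 for real tokens and 0 for padding positions."""
    return (np.asarray(token_ids) != pad_id).astype("int32")

# Toy stand-in for X_train: two sequences padded to length 5.
X_demo = np.array([[5, 12, 7, 0, 0],
                   [3, 8, 0, 0, 0]], dtype="int32")
mask_demo = make_attention_mask(X_demo)
print(mask_demo.tolist())
# → [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]
```

With the mask available, the call becomes model.fit([X_train, make_attention_mask(X_train)], y_train, epochs=32, batch_size=32, verbose=1), and X_test needs the same two-array list for evaluate or predict. Separately, if the IDs come from the custom vocabulary returned by load_data() rather than from the bert-base-uncased tokenizer, they most likely do not match BERT's WordPiece vocabulary; the transformers BertTokenizer returns both input_ids and attention_mask directly and avoids that mismatch.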
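So before fixing the error, you should understand the model better: you need to know what your model expects, that is, the structure of its inputs and outputs. Here, that means passing fit a list of two arrays, the token IDs and the matching attention masks, in the same order as the Input layers.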
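The output side has a matching requirement. create_model ends in Dense(3, activation="softmax") and is compiled with CategoricalCrossentropy, so y_train must be one-hot with shape (num_samples, 3). A small sketch, assuming the labels are either already one-hot or integer class indices from 0 to 2; to_one_hot is a hypothetical helper:

```python
import numpy as np

# Must match the width of the final Dense layer in create_model.
NUM_CLASSES = 3

def to_one_hot(labels, num_classes: int = NUM_CLASSES) -> np.ndarray:
    """Return labels as a float32 one-hot array of shape (N, num_classes)."""
    labels = np.asarray(labels)
    if labels.ndim == 2:
        # Already one-hot: only verify that the width matches the output layer.
        if labels.shape[1] != num_classes:
            raise ValueError(
                f"expected {num_classes} label columns, got {labels.shape[1]}"
            )
        return labels.astype("float32")
    # Integer class indices of shape (N,): select rows of the identity matrix.
    return np.eye(num_classes, dtype="float32")[labels.astype("int64")]

print(to_one_hot([0, 2, 1]).tolist())
# → [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```

If the labels turn out to have a different number of columns, the ValueError surfaces the mismatch with Dense(3) before training starts. Alternatively, keeping integer labels and switching to SparseCategoricalCrossentropy and SparseCategoricalAccuracy avoids the conversion altogether.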