我是Tensorflow的新手,我正在尝试实现一个反语检测模型。我的数据集由标记为1或0的tweet组成,以指示它们是否具有讽刺意味。
在预处理、标记化和填充阶段之后,我留下了固定长度的序列和一个相关的标签向量,在训练和测试集中进行分割,并将其作为模型的输入。这些序列的形式如下:
>>> data
array([[ 1, 677, 348, ..., 0, 0, 0],
[ 1, 677, 348, ..., 0, 0, 0],
[ 1, 825, 1, ..., 0, 0, 0],
...,
[ 908, 1376, 686, ..., 0, 0, 0],
[ 8, 158, 14579, ..., 0, 0, 0],
[ 1, 1, 35, ..., 0, 0, 0]], dtype=int32)
>>> data.shape
(3977, 50)
>>> data[0].shape
(50,)
模型如下:
num_words = len(tok.word_index) + 1 # tok is a Tokenizer which I fit on the data
import tensorflow as tk
from tensorflow import keras
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)
# The model
model = keras.Sequential()
model.add(keras.layers.Embedding(num_words, 64, input_length=Config.SEQUENCE_LENGTH, mask_zero=True))
model.add(keras.layers.GRU(64, return_sequences=True))
model.add(keras.layers.GRU(64))
model.add(keras.layers.Dense(1, activation='sigmoid'))
用model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
编译模型,用sklearn
的效用函数分割数据集后,我调用模型的拟合方法:
model.fit(x_train, y_train, batch_size=10, epochs=10, validation_split=0.1, callbacks=[early_stopping])
在训练模型后,evaluate
方法按预期工作,将x_test
和y_test
作为输入,但如果我调用model.predict_classes(x_test[0])
(或(model.predict(x_test[0]) > 0.5).astype("int32")
)而不是单个预测,我将得到一个(50,1)形状的预测数组。我尝试用这种方式重塑x_test[0],model.predict_classes(x_test[0].reshape(1,50))
,我在数组中得到一个预测:array([[1]], dtype=int32)
所以现在我留下了以下问题(也由于我在调用evaluate(x_test, y_test)
时得到0.6的精度:
- 为什么如果我将数据集传递给模型作为数组(x_train)的数组,我不能只是将测试集的元素传递给预测函数(例如
x_test[0]
),但我必须重塑它? - 是否正常还是有什么错误?我是否设置了模型的输入尺寸错误?在将序列输入模型之前,我还应该重塑它们吗?
为什么我将数据集作为数组的数组传递给模型(x_train),我不能将测试集的一个元素传递给预测函数(例如x_test[0]),但我必须重塑它吗?
因为model.predict
只能接受一组示例而不是单个示例,如果你想在其中提供单个示例,你必须将其重塑为(1,50)
,因此它是一组例如,集合的大小为1。
但是,与其将一组示例逐一输入model.predict
,不如将一组示例输入model.predict
,即pred_test = model.predict(x_test)
,然后如果想知道第i个示例的预测结果,则执行prediction_of_the_ith_example = pred_test[i]
是正常还是有错误?我是在设置输入尺寸吗模型错了吗?我要在喂食前重新排列吗他们是模特吗?
模型定义和训练部分正确。
对于标签(y_train
和y_test
)的形状,正确使用的形状为(data_size,label_length),即分别为(3181,1)
和(796,1)
。这里你做的是单标签分类,所以标签长度是1。但即使你使用(3181,)
和(796,)
,它也能正常工作,因为它为你做了自动广播。