Test accuracy is always high no matter how small my training set is



I'm working on a project where I'm trying to classify comments into different categories: "toxic", "severe_toxic", "obscene", "insult", and "identity_hate". The dataset I'm using comes from this Kaggle challenge: https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge. The problem I'm currently facing is that no matter how small the training set I fit the data on, when I predict labels for the test data my accuracy is always around 90% or higher. In this case, I'm training on 15 rows of data and testing on 159,556 rows. I would normally be excited about high test accuracy, but here I feel like I must be doing something wrong.

I'm reading the data into a pandas DataFrame:

import pandas as pd

trainData = pd.read_csv('train.csv')

Here is what the data looks like when printed:

                      id                                       comment_text  
0       0000997932d777bf  Explanation\nWhy the edits made under my usern...   
1       000103f0d9cfb60f  D'aww! He matches this background colour I'm s...   
2       000113f07ec002fd  Hey man, I'm really not trying to edit war. It...   
3       0001b41b1c6bb37e  "\nMore\nI can't make any real suggestions on ...   
4       0001d958c54c6e35  You, sir, are my hero. Any chance you remember...   
...                  ...                                                ...   
159566  ffe987279560d7ff  ":::::And for the second time of asking, when ...   
159567  ffea4adeee384e90  You should be ashamed of yourself \n\nThat is ...   
159568  ffee36eab5c267c9  Spitzer \n\nUmm, theres no actual article for ...   
159569  fff125370e4aaaf3  And it looks like it was actually you who put ...   
159570  fff46fc426af1f9a  "\nAnd ... I really don't think you understand...   

        toxic  severe_toxic  obscene  threat  insult  identity_hate  
0           0             0        0       0       0              0  
1           0             0        0       0       0              0  
2           0             0        0       0       0              0  
3           0             0        0       0       0              0  
4           0             0        0       0       0              0  
...       ...           ...      ...     ...     ...            ...  
159566      0             0        0       0       0              0  
159567      0             0        0       0       0              0  
159568      0             0        0       0       0              0  
159569      0             0        0       0       0              0  
159570      0             0        0       0       0              0  

[159571 rows x 8 columns]

Then I split the data into train and test sets with train_test_split:

from sklearn.model_selection import train_test_split

X = trainData.drop(labels=['id','toxic','severe_toxic','obscene','threat','insult','identity_hate'], axis=1)
Y = trainData.drop(labels=['id','comment_text'], axis=1)
trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.9999, random_state=99)
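For reference, a quick sanity check of how this split comes out (a minimal sketch using the variable names above; exact counts depend on the split):

# Split sizes and how rare the positive labels actually are
print(trainX.shape, testX.shape)   # roughly (15, 1) vs (159556, 1) with test_size=0.9999
print(trainY.mean())               # per-label fraction of positive examples
print(testY.mean())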

I'm using sklearn's HashingVectorizer to convert the comments into numeric vectors for classification:

from sklearn.feature_extraction.text import HashingVectorizer

def hashVec():
    # Gather the raw comment strings from the train and test splits
    trainComments = []
    testComments = []
    for index, row in trainX.iterrows():
        trainComments.append(row['comment_text'])
    for index, row in testX.iterrows():
        testComments.append(row['comment_text'])
    # Hash each comment into a fixed-size sparse feature vector
    vectorizer = HashingVectorizer()
    trainSamples = vectorizer.transform(trainComments)
    testSamples = vectorizer.transform(testComments)
    return trainSamples, testSamples
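As an aside, the row-by-row loop isn't strictly necessary: HashingVectorizer.transform accepts any iterable of strings, so a minimal equivalent sketch would be:

from sklearn.feature_extraction.text import HashingVectorizer

def hashVec():
    # Pass the comment_text columns directly; transform() takes any iterable of strings
    vectorizer = HashingVectorizer()
    trainSamples = vectorizer.transform(trainX['comment_text'])
    testSamples = vectorizer.transform(testX['comment_text'])
    return trainSamples, testSamples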

I'm using sklearn's OneVsRestClassifier with LogisticRegression to fit and predict the data for each of the 6 classes:

from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

def logRegOVR(trainSamples, testSamples):
    commentTypes = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    clf = OneVsRestClassifier(LogisticRegression(solver='sag'))
    for cType in commentTypes:
        print(cType, ":")
        # Fit on the training split, then report accuracy on both splits
        clf.fit(trainSamples, trainY[cType])
        pred1 = clf.predict(trainSamples)
        print("\tTrain Accuracy:", accuracy_score(trainY[cType], pred1))
        prediction = clf.predict(testSamples)
        print("\tTest Accuracy:", accuracy_score(testY[cType], prediction))

Finally, here is where I call the functions, along with the output I get:

sol = hashVec()
logRegOVR(sol[0],sol[1])
toxic :
Train Accuracy: 0.8666666666666667
Test Accuracy: 0.9041590413397177
severe_toxic :
Train Accuracy: 1.0
Test Accuracy: 0.9900035097395272
obscene :
Train Accuracy: 1.0
Test Accuracy: 0.9470468048835519
threat :
Train Accuracy: 1.0
Test Accuracy: 0.9970041866178646
insult :
Train Accuracy: 1.0
Test Accuracy: 0.9506317531148938
identity_hate :
Train Accuracy: 1.0
Test Accuracy: 0.9911943142219659

The test accuracy is very similar when I use a more reasonable train_test_split of 80% training and 20% testing.

Thanks for the help.

You're not using a good metric: accuracy is not a good way to determine whether you're doing things right here. I suggest you look at what is called the F1 score, which is a blend of precision and recall; I find it much more relevant when evaluating how a classifier is working.
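For example, a minimal sketch of reporting F1 as well, using sklearn's f1_score on the same predictions computed inside the loop of logRegOVR above:

from sklearn.metrics import f1_score

# F1 balances precision and recall, so a model that only predicts "clean" no longer scores well
print("\tTest F1:", f1_score(testY[cType], prediction))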

If it's an imbalanced dataset, accuracy doesn't mean much. If 90% of the comments in the dataset don't belong to any of these "toxic" categories and the model always predicts that a comment is "clean", you still get 90% accuracy.
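To see this directly, here is a minimal sketch comparing against a majority-class baseline with sklearn's DummyClassifier, reusing the variables from the question:

from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, f1_score

# A "classifier" that always predicts the most frequent class (i.e., "clean")
baseline = DummyClassifier(strategy='most_frequent')
baseline.fit(trainSamples, trainY['toxic'])
basePred = baseline.predict(testSamples)
print("Baseline accuracy:", accuracy_score(testY['toxic'], basePred))  # already around 0.90
print("Baseline F1:", f1_score(testY['toxic'], basePred))              # roughly 0.0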
