在熊猫中比较绘制不同二元分类器的 ROC 曲线的最简单方法是什么?



我有三个二元分类模型,我到达了以下一点,试图将它们组装成最终的比较ROC图。

import pandas as pd
import numpy as np
import sklearn.metrics as metrics
y_test = ... # a numpy array containing the test values
dfo    = ... # a pd.DataFrame containing the model predictions
dfroc = dfo[['SVM',
'RF',
'NN']].apply(lambda y_pred: metrics.roc_curve(y_test[:-1], y_pred[:-1])[0:2], 
axis=0, result_type='reduce')
print(dfroc)
dfroc_auc = dfroc.apply(lambda x: metrics.auc(x[0], x[1]))
print(dfroc_auc)

输出以下内容(其中dfrocdfroc_auc的类型为pandas.core.series.Series(:

SVM     ([0.0, 0.016666666666666666, 1.0], [0.0, 0.923...
RF      ([0.0, 0.058333333333333334, 1.0], [0.0, 0.769...
NN      ([0.0, 0.06666666666666667, 1.0], [0.0, 1.0, 1...
dtype: object
SVM     0.953205
RF      0.855449
NN      0.966667
dtype: float64

为了能够将它们绘制为比较 ROC,我需要将它们转换为以下枢轴结构,如dfrocpd.DataFrame......如何实现这种枢轴化?

model   fpr       tpr
1     SVM     0.0       0.0        
2     SVM     0.16666   0.923
3     SVM     1.0       ...
4     RF      0.0       0.0       
5     RF      0.05833   0.769
6     RF      1.0       ... 
7     NN      ...       ...

然后,对于绘图和遵循说明,如何在 Python 中绘制 ROC 曲线将是这样的:

import matplotlib.pyplot as plt
plt.title('Receiver Operating Characteristic')
dfroc.plot(label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()

不是理想的结构,但假设您有如下内容:

s = pd.Series({'SVC':([0.0, 0.016, 1.0], [0.0, 0.923, 0.5], [0.3, 0.4, 0.9]),
'RF': ([0.0, 0.058, 1.0], [0.0, 0.923, 0.2], [0.5, 0.3, 0.9]),
'NN': ([0.0, 0.06,  1.0], [0.0, 0.13, 0.4], [0.2, 0.4, 0.9])})

您可以定义一个函数来计算TPRFPR,并返回具有指定结构的数据帧:

def tpr_fpr(g):
model, cm = g
cm = np.stack(cm.values)
diag = np.diag(cm)
FP = cm.sum(0) - diag   
FN = cm.sum(1) - diag
TP = diag
TN = cm.sum() - (FP + FN + TP)
TPR = TP/(TP+FN)
FPR = FP/(FP+TN)
return pd.DataFrame({'model':model,
'TPR':TPR, 
'FPR':FPR})

并从第一级的groupby,并将上述功能应用于每个组:

out = pd.concat([tpr_fpr(g) for g in s.explode().groupby(level=0)])
<小时 />
print(out)
model       TPR       FPR
0    NN  0.000000  0.098522
1    NN  0.245283  0.179688
2    NN  0.600000  0.880503
0    RF  0.000000  0.177117
1    RF  0.821906  0.129804
2    RF  0.529412  0.550206
0   SVC  0.000000  0.099239
1   SVC  0.648630  0.159021
2   SVC  0.562500  0.615006

最新更新