scikitplot -> IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed



Can anyone help me fix this error?

y_test['cEXT'].shape, y_test['cEXT'].ndim   # returns ((982,), 1)
Y_test_probs.shape, Y_test_probs.ndim         # returns ((982,), 1)
# AUC ROC curve
import numpy as np
import scikitplot as skplt
Y_test_probs = np.squeeze(model_cEXT.predict(X_test))
skplt.metrics.plot_roc_curve(y_test['cEXT'], Y_test_probs,
                             title="Digits ROC Curve", figsize=(12,6));

Error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-136-33407ee6fd1a> in <module>()
4 
5 skplt.metrics.plot_roc_curve(y_test['cEXT'], Y_test_probs,
----> 6                        title="Digits ROC Curve", figsize=(12,6));
1 frames
/usr/local/lib/python3.7/dist-packages/scikitplot/metrics.py in plot_roc_curve(y_true, y_probas, title, curves, ax, figsize, cmap, title_fontsize, text_fontsize)
255     roc_auc = dict()
256     for i in range(len(classes)):
--> 257         fpr[i], tpr[i], _ = roc_curve(y_true, probas[:, i],
258                                       pos_label=classes[i])
259         roc_auc[i] = auc(fpr[i], tpr[i])
IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed
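The failing line in the traceback is probas[:, i], which uses two indices (rows and column i). A minimal sketch, using only NumPy and a made-up three-sample array standing in for the squeezed Y_test_probs, reproduces the same error without the model:

```python
import numpy as np

# Made-up 1-D array shaped like Y_test_probs after np.squeeze: one value per sample
probas = np.array([0.1, 0.8, 0.4])

message = None
try:
    probas[:, 0]  # two indices on a 1-D array, as in plot_roc_curve's loop
except IndexError as err:
    message = str(err)

print(message)
```

The exact wording of the message depends on the NumPy version, but the cause is the same: the array has one axis and the code indexes two.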

It seems that y_probas needs two dimensions: (n_samples, n_classes).
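Because the loop in the traceback runs over every class in y_true and reads column i of the probabilities, it can help to check the shape before plotting. The helper below is hypothetical (it is not part of scikitplot) and simply fails early with a readable message when y_probas is not (n_samples, n_classes):

```python
import numpy as np


def check_probas_shape(y_true, y_probas):
    """Hypothetical helper: verify y_probas is (n_samples, n_classes) before plotting."""
    y_true = np.asarray(y_true)
    y_probas = np.asarray(y_probas)
    n_classes = len(np.unique(y_true))  # one probability column expected per class
    if y_probas.ndim != 2:
        raise ValueError(
            f"y_probas must be 2-D (n_samples, n_classes), got shape {y_probas.shape}"
        )
    expected = (len(y_true), n_classes)
    if y_probas.shape != expected:
        raise ValueError(f"expected y_probas shape {expected}, got {y_probas.shape}")
    return y_probas


# A 1-D vector is rejected; a two-column array for two classes is accepted
ok = check_probas_shape([0, 1, 1], [[0.8, 0.2], [0.3, 0.7], [0.1, 0.9]])
```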

Maybe you can try adding a dimension:

np.expand_dims(Y_test_probs, 1)
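One caveat, assuming cEXT is a binary 0/1 label: the loop in the traceback runs for i in range(len(classes)), so with two classes it will also read probas[:, 1], and a (982, 1) array from np.expand_dims would likely still fail, this time with an out-of-bounds index. A sketch that instead builds one column per class, assuming model_cEXT.predict returns the probability of class 1 (for example, a single sigmoid output), could look like this; to_two_column_probas is a hypothetical name and the values in p are made up:

```python
import numpy as np

# Made-up stand-in for Y_test_probs: assumed P(class 1) for each sample
p = np.array([0.2, 0.7, 0.9, 0.4])

print(np.expand_dims(p, 1).shape)  # (4, 1): only one column


def to_two_column_probas(p_positive):
    """Hypothetical helper: turn P(class 1) into columns [P(class 0), P(class 1)]."""
    p_positive = np.asarray(p_positive, dtype=float).ravel()
    return np.column_stack([1.0 - p_positive, p_positive])


probas = to_two_column_probas(p)
print(probas.shape)  # (4, 2): one column per class
```

Under those assumptions the plotting call would become skplt.metrics.plot_roc_curve(y_test['cEXT'], to_two_column_probas(Y_test_probs), title="Digits ROC Curve", figsize=(12,6)). The column order follows the sorted class labels, which is the order np.unique produces for classes.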
