I am trying to draw a waterfall plot with the SHAP library to explain a single model prediction, like this:
ex = shap.Explanation(shap_values[0],
                      explainer.expected_value,
                      X.iloc[0],
                      columns)
ex
ex returns:
.values =
array([-2.27243590e-01, 5.41666667e-02, 3.33333333e-03, 2.21153846e-02,
1.92307692e-04, -7.17948718e-02])
.base_values =
0.21923076923076923
.data =
BMI 18.716444
ROM-PADF-KE_D 33
Asym-ROM-PHIR(≥8)_discr 1
Asym_SLCMJLanding-pVGRF(10percent)_discr 1
Asym_TJ_Valgus_FPPA(10percent)_discr 1
DVJ_Valgus_KneeMedialDisplacement_D_discr 0
Name: 0, dtype: object
But when I try to draw the waterfall plot, I get the following error:
shap.waterfall_plot(ex)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
/tmp/ipykernel_4785/3628025354.py in <module>
----> 1 shap.waterfall_plot(ex)
/usr/local/lib/python3.8/dist-packages/shap/plots/_waterfall.py in waterfall(shap_values, max_display, show)
120 yticklabels[rng[i]] = feature_names[order[i]]
121 else:
--> 122 yticklabels[rng[i]] = format_value(features[order[i]], "%0.03f") + " = " + feature_names[order[i]]
123
124 # add a last grouped feature to represent the impact of all the features we didn't show
/usr/local/lib/python3.8/dist-packages/shap/utils/_general.py in format_value(s, format_str)
232 s = format_str % s
233 s = re.sub(r'\.?0+$', '', s)
--> 234 if s[0] == "-":
235 s = u"\u2212" + s[1:]
236 return s
IndexError: string index out of range
Edit: minimal reproducible example.
The explainer is a KernelExplainer:
explainer_2 = shap.KernelExplainer(sci_Model_2.predict, X)
shap_values_2 = explainer.shap_values(X)
X and y are lists taken from DataFrames:
y = data_modelo_1_2_csv_encoded['Soft-Tissue_injury_≥4days']
y_list = label_encoder.fit_transform(y)
X = data_modelo_1_2_csv_encoded.drop('Soft-Tissue_injury_≥4days',axis=1)
X_list = X.to_numpy()
The model is a small Python wrapper around a Weka model, so that Weka models can be used with Python libraries such as SHAP. It looks like this:
class weka_classifier(BaseEstimator, ClassifierMixin):
    def __init__(self, classifier = None, dataset = None):
        if classifier is not None:
            self.classifier = classifier
        if dataset is not None:
            self.dataset = dataset
            self.dataset.class_is_last()
        if index is not None:
            self.index = index
    def fit(self, X, y):
        return self.fit2()
    def fit2(self):
        return self.classifier.build_classifier(self.dataset)
    def predict_instance(self,x):
        x.append(0.0)
        inst = Instance.create_instance(x,classname='weka.core.DenseInstance', weight=1.0)
        inst.dataset = self.dataset
        return self.classifier.classify_instance(inst)
    def predict_proba_instance(self,x):
        x.append(0.0)
        inst = Instance.create_instance(x,classname='weka.core.DenseInstance', weight=1.0)
        inst.dataset = self.dataset
        return self.classifier.distribution_for_instance(inst)
    def predict_proba(self,X):
        prediction = []
        for i in range(X.shape[0]):
            instance = []
            for j in range(X.shape[1]):
                instance.append(X[i][j])
            instance.append(0.0)
            instance = Instance.create_instance(instance,classname='weka.core.DenseInstance', weight=1.0)
            instance.dataset=self.dataset
            prediction.append(self.classifier.distribution_for_instance(instance))
        return np.asarray(prediction)
    def predict(self,X):
        prediction = []
        for i in range(X.shape[0]):
            instance = []
            for j in range(X.shape[1]):
                instance.append(X[i][j])
            instance.append(0.0)
            instance = Instance.create_instance(instance,classname='weka.core.DenseInstance', weight=1.0)
            instance.dataset=self.dataset
            prediction.append(self.classifier.classify_instance(instance))
        return np.asarray(prediction)
    def set_data(self,dataset):
        self.dataset = dataset
        self.dataset.class_is_last()
The dataset is an ARFF file converted to CSV and loaded as a DataFrame, with the following variables:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 260 entries, 0 to 259
Data columns (total 7 columns):
 #   Column                                     Non-Null Count  Dtype
---  ------                                     --------------  -----
 0   BMI                                        260 non-null    float64
 1   ROM-PADF-KE_D                              260 non-null    int64
 2   Asym-ROM-PHIR(≥8)_discr                    260 non-null    int64
 3   Asym_SLCMJLanding-pVGRF(10percent)_discr   260 non-null    int64
 4   Asym_TJ_Valgus_FPPA(10percent)_discr       260 non-null    int64
 5   DVJ_Valgus_KneeMedialDisplacement_D_discr  260 non-null    int64
 6   Soft-Tissue_injury_≥4days                  260 non-null    category
dtypes: category(1), float64(1), int64(5)
Your problem is probably that the 0 in the .data field is a string rather than a number. I can reproduce the same error with format_value('0', "%0.03f").
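The reproduction can be run without installing shap. The function below is a reconstruction of format_value based on the lines visible in the traceback; the type check before the formatting step is an assumption (line 231 is not shown in the traceback), so treat it as a sketch rather than the library's exact source.

```python
import re


def format_value(s, format_str):
    """Reconstruction of shap's format_value, based on the traceback."""
    # Assumption: only non-string inputs are passed through format_str.
    if not isinstance(s, str):
        s = format_str % s
    # Removes an optional dot followed by ALL trailing zeros.
    s = re.sub(r'\.?0+$', '', s)
    # Fails on an empty string, which is what '0' has become by now.
    if s[0] == "-":
        s = u"\u2212" + s[1:]
    return s


print(format_value(0, "%0.03f"))      # 0   (number: formatted as 0.000 first)
print(format_value('100', "%0.03f"))  # 1   (string: trailing zeros eaten)
try:
    format_value('0', "%0.03f")
except IndexError as err:
    print("IndexError:", err)         # IndexError: string index out of range
```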
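Looking at the current format_value, we can see that it strips every trailing zero from the string it is given. In particular, format_value('100', "%0.03f") returns 1. This is a bug, and the regular expression should be replaced (see for example https://stackoverflow.com/a/26299205/4178189).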
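One way the regular expression could be tightened is to strip zeros only when they follow a decimal point. The sketch below is an illustration of that idea, not the pattern from the linked answer and not shap's eventual fix; format_value_fixed is a hypothetical name.

```python
import re


def format_value_fixed(s, format_str):
    """Hypothetical variant that never strips zeros from an integer part."""
    if not isinstance(s, str):
        s = format_str % s
    # Strip trailing zeros only if they come after a decimal point ...
    s = re.sub(r'(\.\d*?)0+$', r'\1', s)
    # ... then drop the dot if nothing is left behind it.
    s = re.sub(r'\.$', '', s)
    # startswith() is safe even for an empty string, unlike s[0].
    if s.startswith("-"):
        s = "\u2212" + s[1:]
    return s


print(format_value_fixed('100', "%0.03f"))  # 100
print(format_value_fixed('0', "%0.03f"))    # 0
print(format_value_fixed(1.5, "%0.03f"))    # 1.5
```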
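Note that when you pass a number (e.g. 100 or 0), it is first converted to a string (100.000 or 0.000), so the trailing zeros being stripped all sit after the decimal point. That is why the function does not show the bug when it is called with a number (int or float).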
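The development version of shap (not yet released) is also unaffected by this problem, because format_value is no longer called when waterfall_plot receives a non-numeric value; see: https://github.com/slundberg/shap/blob/8926cd0122d0a1b3cca0768f2c386de706090668/shap/plots/_waterfall.py#L127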
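The idea behind that change can be sketched roughly as follows. This is not the shap source; make_label is a hypothetical helper that only shows the branching on numeric versus non-numeric values when building a y-axis label.

```python
from numbers import Number


def make_label(value, feature_name):
    """Hypothetical helper: build a label like '18.716 = BMI'."""
    if isinstance(value, Number):
        # Numeric value: format it, then trim zeros after the decimal point.
        text = ("%0.03f" % value).rstrip("0").rstrip(".")
    else:
        # Non-numeric value (e.g. the string '0'): use it unchanged.
        text = str(value)
    return text + " = " + feature_name


print(make_label(18.716444, "BMI"))  # 18.716 = BMI
print(make_label('0', "DVJ_Valgus_KneeMedialDisplacement_D_discr"))
```

With the released version, the same reasoning suggests a workaround: make sure the entries in .data are real numbers rather than strings before calling waterfall_plot.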
Note: this question has also been reported as a GitHub issue; see https://github.com/slundberg/shap/issues/2581#issuecomment-1155134604