r语言 - 基于随机森林的多分类问题的近似SHAP值



我想使用fastshap包来获得使用随机森林分类器的多分类问题中每个类别的结果的SHAP值图。我只能找到一些代码块,但没有解释如何从一开始就获得这种情况下的SHAP值。以下是我到目前为止的代码(我的y有5个类,这里我试图获得类3的SHAP值):

library(randomForest)
library(fastshap)
set.seed(42) 
sample <- sample.int(n = nrow(ITA), size = floor(.75*nrow(ITA)), replace=F)
train <- ITA [sample,]
test <- ITA [-sample,]
set.seed(42)
rftrain <-randomForest(y ~ ., data=train, ntree=500, importance = TRUE) 
p_function_3<- function(object, newdata) 
caret::predict.train(object, 
newdata = newdata, 
type = "prob")[,3]
shap_values_G <- fastshap::explain(rftrain, 
X = train, 
pred_wrapper = p_function_3, 
nsim = 50,
newdata=train[which(y==3),])

现在,我把代码大部分来自我在网上找到的一个例子,我试图适应它(我不是一个专家R用户),但它不起作用。你能帮我改一下吗?谢谢!

这是一个工作示例(使用不同的数据集),但我认为逻辑是相同的。

library(randomForest)
library(fastshap)
set.seed(42) 
ix <- sample(nrow(iris), 0.75 * nrow(iris))
train <- iris[ix, ]
test <- iris[-ix, ]
xvars <- c("Sepal.Width", "Sepal.Length")
yvar <- "Species"
fit <- randomForest(reformulate(xvars, yvar), data = train, ntree = 500) 
pred_3 <- function(model, newdata) {
predict(model, newdata = newdata, type = "prob")[, "virginica"]
}
shap_values_3 <- fastshap::explain(
fit, 
X = train,             # Reference data
feature_names = xvars,
pred_wrapper = pred_3, 
nsim = 50,
newdata = train[train$Species == "virginica", ] # For these rows, you will calculate explanations
)
head(shap_values_3)
# Sepal.Width Sepal.Length
# <dbl>        <dbl>
# 1      0.101        0.381 
# 2      0.159       -0.0109
# 3      0.0736      -0.0285
# 4      0.0564       0.161 
# 5      0.0649       0.594 
# 6      0.232        0.0305

相关内容

  • 没有找到相关文章

最新更新