r - How to build partial dependence plots of BRT response shapes?



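I'm not experienced in R and I need some help. I'm an ecologist, and I have a site-by-variable matrix in which "TD0", "TD1" and "TD2" are the response variables and "Chao", "age" and "slope" are the explanatory variables: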

db.chao <- read.table(text = "
   ID.plot TD0    TD1    TD2      Chao age slope
1 GS_Ci01N  20  8.898  6.488 0.6521390  26     2
2 GS_Ci03N  26  7.788  4.883 0.2335441  26     2
3 GS_Ci04N  31 10.482  7.282 0.5234748  26     0
4 GS_Ci05N  47 18.108 11.989 0.3110385  26     3
5 GS_Ci06N  47 16.332 10.107 0.4529010  26     0
6 GS_Ci07N  31  9.478  5.725 0.5524426  26     1
", header = TRUE)

I built boosted regression trees (BRT) to identify thresholds in the response shapes. I used the "dismo" and "gbm" packages:

library(gbm)
library(dismo) # gbm.step() and gbm.plot()

# Columns: 2 = TD0, 3 = TD1, 4 = TD2; 5:7 = Chao, age, slope
mod0 <- gbm.step(data = db.chao, gbm.x = 5:7, gbm.y = 2, family = "poisson",
                 tree.complexity = 5, learning.rate = 0.0025, bag.fraction = 0.5)

mod1 <- gbm.step(data = db.chao, gbm.x = 5:7, gbm.y = 3, family = "gaussian",
                 tree.complexity = 5, learning.rate = 0.0025, bag.fraction = 0.5)

mod2 <- gbm.step(data = db.chao, gbm.x = 5:7, gbm.y = 4, family = "gaussian",
                 tree.complexity = 5, learning.rate = 0.0025, bag.fraction = 0.5)

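This gave me three models:

  • "mod0", describing the relationship between TD0 and the explanatory variables
  • "mod1", describing the relationship between TD1 and the explanatory variables
  • "mod2", describing the relationship between TD2 and the explanatory variables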
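Note that mod0 uses family = "poisson", so for it predict() with the default type = "link" returns values on the log scale, while mod1 and mod2 (gaussian) predict on the original scale. This matters when the three curves are overlaid in one plot. A minimal sketch, assuming gbm::gbm() fitted directly on simulated counts (hypothetical data, not db.chao), showing that type = "response" is simply exp() of the link-scale prediction:

```r
library(gbm)

set.seed(2)
n <- 200
# Simulated count data (hypothetical), standing in for a count response like TD0
sim <- data.frame(x1 = runif(n), x2 = runif(n))
sim$y <- rpois(n, lambda = exp(1 + 1.5 * sim$x1))

fit <- gbm(y ~ x1 + x2, data = sim, distribution = "poisson",
           n.trees = 300, interaction.depth = 2, shrinkage = 0.05)

link <- predict(fit, newdata = sim, n.trees = 300)                    # log scale (default)
resp <- predict(fit, newdata = sim, n.trees = 300, type = "response") # count scale
head(cbind(link, resp))
```

Adding type = "response" to the predict() call used for mod0 would therefore put the TD0 curve on the count scale.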

For each of them, I built a panel plot like the following (these are just examples):

[Figures: three example panel plots]

So for each response variable I have three plots, one per explanatory variable.

I obtained them with this script:

gbm.plot(mod0, n.plots = 3, write.title = FALSE, main = "TD0", rug = TRUE, smooth = TRUE, plot.layout = c(1, 3), common.scale = TRUE)
gbm.plot(mod1, n.plots = 3, write.title = FALSE, main = "TD1", rug = TRUE, smooth = TRUE, plot.layout = c(1, 3), common.scale = TRUE)
gbm.plot(mod2, n.plots = 3, write.title = FALSE, main = "TD2", rug = TRUE, smooth = TRUE, plot.layout = c(1, 3), common.scale = TRUE)

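What I actually want is three plots, one per explanatory variable, and then, if possible, in each plot I would like to overlay the three response-variable shapes (as three different lines or colors).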

I think I should use the "pdp" package to build the partial dependence plots, but I haven't managed to do it.
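For reference, a partial dependence curve is computed like this: for each value v on a grid, set the explanatory variable to v in every row of the data, predict, and average the predictions, so the other variables keep their observed values. The sketch below does this by hand with lm() on simulated data (hypothetical, so it runs without the BRT models); partial_dependence() is a hypothetical helper, not part of pdp. It is the same averaging that pdp::partial() performs when its pred.fun returns the mean prediction, as in the answer below.

```r
set.seed(42)
n <- 100
# Simulated data (hypothetical), not db.chao
sim <- data.frame(x1 = runif(n), x2 = runif(n))
sim$y <- 3 * sim$x1 - 2 * sim$x2 + rnorm(n, sd = 0.1)
fit <- lm(y ~ x1 + x2, data = sim)

# Hypothetical helper: brute-force partial dependence of the predictions on `var`
partial_dependence <- function(model, data, var, grid) {
  yhat <- vapply(grid, function(v) {
    tmp <- data
    tmp[[var]] <- v                     # same value of var for every row
    mean(predict(model, newdata = tmp)) # average over the other variables
  }, numeric(1))
  data.frame(value = grid, yhat = yhat)
}

pd <- partial_dependence(fit, sim, "x1", grid = seq(0, 1, length.out = 11))
plot(pd$value, pd$yhat, type = "l", xlab = "x1", ylab = "partial dependence")
```

For a linear model the curve is a straight line whose slope equals the fitted coefficient of x1, which makes it a quick sanity check; for a BRT the same procedure reveals the nonlinear response shape.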

I would be grateful if anyone could help me.

Thanks!

I'm not quite sure how gbm works or why it needs the number of trees to predict the output, but here is a working example using the pdp and gridExtra packages:

library(gbm) # provides the predict() method for gbm models
library(pdp)

ntrees <- 250 # Number of trees to use to predict data
# pred.fun for partial(): returning the mean prediction over the training
# data gives one value per grid point, i.e. a partial dependence curve
pred <- function(object, newdata) {
  yhat <- predict(object, newdata, n.trees = ntrees)
  mean(yhat)
}
pdps1 <- pdps2 <- pdps3 <- list()
for (i in 1:3) {
  var <- names(db.chao)[i + 4] # Chao, age, slope
  pdps1[[i]] <- partial(mod0, pred.var = var, train = db.chao, plot = TRUE,
                        pred.fun = pred, recursive = FALSE)
  pdps2[[i]] <- partial(mod1, pred.var = var, train = db.chao, plot = TRUE,
                        pred.fun = pred, recursive = FALSE)
  pdps3[[i]] <- partial(mod2, pred.var = var, train = db.chao, plot = TRUE,
                        pred.fun = pred, recursive = FALSE)
}
gridExtra::grid.arrange(grobs = pdps1, nrow = 1) # For the first model
gridExtra::grid.arrange(grobs = pdps2, nrow = 1) # For the second model
gridExtra::grid.arrange(grobs = pdps3, nrow = 1) # For the third model
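On why ntrees is needed: a gbm model is an additive ensemble, so its prediction is an initial value plus the summed contributions of the first n.trees trees, and predict() has to be told how many trees to include. The sketch below uses simulated data (hypothetical, not db.chao) and gbm::gbm() directly rather than gbm.step(); it shows that predict() returns one column per requested number of trees. For gbm.step() models, the number of trees selected by cross-validation is, as far as I know, stored in mod0$gbm.call$best.trees (an assumption worth checking with str(mod0$gbm.call)), which could replace a hand-picked ntrees.

```r
library(gbm)

set.seed(1)
n <- 200
# Simulated data (hypothetical), not the question's db.chao
sim <- data.frame(x1 = runif(n), x2 = runif(n))
sim$y <- sin(2 * pi * sim$x1) + sim$x2 + rnorm(n, sd = 0.2)

fit <- gbm(y ~ x1 + x2, data = sim, distribution = "gaussian",
           n.trees = 500, interaction.depth = 2, shrinkage = 0.05,
           bag.fraction = 0.5)

# One column of predictions per requested ensemble size
yhat <- predict(fit, newdata = sim, n.trees = c(10, 100, 500))
# Training error drops as more trees are added to the sum
mse <- colMeans((sim$y - yhat)^2)
mse
```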

Hope this helps!

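EDIT: As requested by the OP, here is how to get all the PDPs in just three plots, using a different number of trees to predict with each model: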

library(gbm) # provides the predict() method for gbm models
library(pdp)
library(ggplot2)
library(gridExtra)

ntrees1 <- 150 # Number of trees to use to predict data with model1 (mod0)
ntrees2 <- 250 # Number of trees to use to predict data with model2 (mod1)
ntrees3 <- 50  # Number of trees to use to predict data with model3 (mod2)
# Build a pred.fun for partial() that averages predictions from n.trees trees
make_pred <- function(n.trees) {
  force(n.trees)
  function(object, newdata) {
    mean(predict(object, newdata, n.trees = n.trees))
  }
}
pred1 <- make_pred(ntrees1)
pred2 <- make_pred(ntrees2)
pred3 <- make_pred(ntrees3)
# Extract the legend of a ggplot so it can be placed separately in grid.arrange.
# Matches "guide-box" in the gtable layout names (ggplot2 >= 3.5.0 uses
# "guide-box-right", "guide-box-bottom", ...) and keeps the first non-empty one.
get_legend <- function(myggplot) {
  tmp <- ggplot_gtable(ggplot_build(myggplot))
  boxes <- tmp$grobs[grepl("guide-box", tmp$layout$name)]
  Filter(function(g) !inherits(g, "zeroGrob"), boxes)[[1]]
}
# Obtain partial dependence data instead of plots, then draw one plot per predictor
plotlist <- list()
for (i in 1:3) {
  var <- names(db.chao)[i + 4] # Chao, age, slope
  pd1 <- partial(mod0, pred.var = var, train = db.chao, plot = FALSE,
                 pred.fun = pred1, recursive = FALSE)
  pd2 <- partial(mod1, pred.var = var, train = db.chao, plot = FALSE,
                 pred.fun = pred2, recursive = FALSE)
  pd3 <- partial(mod2, pred.var = var, train = db.chao, plot = FALSE,
                 pred.fun = pred3, recursive = FALSE)
  pd <- as.data.frame(rbind(pd1, pd2, pd3))
  names(pd)[1] <- "x" # fixed column name, so aes() does not depend on loop variables
  pd$output <- rep(c("TD0", "TD1", "TD2"), each = nrow(pd1))
  # The data are stored inside the plot object, so no local() environment is needed
  plotlist[[i]] <- ggplot(pd, aes(x = x, y = yhat, colour = output)) +
    geom_line() +
    labs(x = var, y = "yhat", colour = "output", title = paste0("PDP of ", var))
}
legend <- get_legend(plotlist[[1]])
plotlist <- lapply(plotlist, function(p) p + theme(legend.position = "none"))
plotlist[[4]] <- legend
grid.arrange(grobs = plotlist, nrow = 1, widths = c(2.3, 2.3, 2.3, 0.8))
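An alternative to extracting the legend by hand is to stack all curves into one long data frame and let facet_wrap() draw one panel per explanatory variable with a single shared legend. The sketch below is self-contained: it uses lm() fits on simulated data with the same column names as db.chao (hypothetical values) as stand-ins for mod0, mod1 and mod2. With the real gbm.step() models, the partial() calls would also need pred.fun = pred1 (and so on) and recursive = FALSE, as in the code above.

```r
library(pdp)
library(ggplot2)

set.seed(3)
n <- 100
# Hypothetical stand-in for db.chao: same predictor names, simulated values
sim <- data.frame(Chao = runif(n), age = runif(n, 0, 50),
                  slope = sample(0:5, n, replace = TRUE))
sim$TD0 <- 20 + 10 * sim$Chao + 0.2 * sim$age + rnorm(n)
sim$TD1 <- 8 + 5 * sim$Chao - 0.5 * sim$slope + rnorm(n)
sim$TD2 <- 6 + 3 * sim$Chao + 0.1 * sim$age + rnorm(n)

# Stand-ins for mod0, mod1, mod2 (linear models, not BRTs)
models <- list(
  TD0 = lm(TD0 ~ Chao + age + slope, data = sim),
  TD1 = lm(TD1 ~ Chao + age + slope, data = sim),
  TD2 = lm(TD2 ~ Chao + age + slope, data = sim)
)
vars <- c("Chao", "age", "slope")

# One row per (response, predictor, grid value)
pd_long <- do.call(rbind, lapply(names(models), function(resp) {
  do.call(rbind, lapply(vars, function(v) {
    pd <- partial(models[[resp]], pred.var = v, train = sim)
    data.frame(variable = v, value = pd[[v]], yhat = pd$yhat, output = resp)
  }))
}))

# One panel per predictor, one coloured line per response, one shared legend
p <- ggplot(pd_long, aes(x = value, y = yhat, colour = output)) +
  geom_line() +
  facet_wrap(~ variable, nrow = 1, scales = "free_x") +
  labs(x = NULL, colour = "output")
print(p)
```

The design trade-off: faceting keeps every panel on a shared y scale by default, which helps compare response shapes directly; adding scales = "free" gives each panel its own y range, as the separate plots above do.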
