r-随机森林回归:提取每棵树终端节点的训练样本



我想实现Bertsimas等人的预测处方方法。(2020(,他们将机器学习方法与优化相结合。为此,我需要查看林中每个决策树的终端节点(离散区域(。

具体来说,我想知道每棵树的以下内容:

  1. 训练样本落在哪个区域
  2. 测试样本属于哪个地区

我希望我的问题通过以下一个决策树的图片变得更清楚:

回归树示例

这里,对于第一终端节点,我对预测m不感兴趣,而是对形成预测基础的值y1、y4和y5感兴趣。


完美的结果将是类似矩阵的结构,其中每列表示一棵树,每行表示一个训练(测试(样本。对于每个样本和树,结构应该给我可以找到样本的区域/终端节点的ID!

我看了randomForestranger包,但没有找到任何相关的东西。。。一些论文提到了用caret包实现该方法,但没有提到如何绕过预测。

下面是一个使用ranger:的可重复回归示例

library(MASS)
library(e1071)
library(ranger)
#load data
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob=c(0.8, 0.2))
train <- Boston[ind == 1,]
test <- Boston[ind == 2,]
#train random forest
boston.rf <- ranger(medv ~ ., data = train) 

非常感谢您的帮助。干杯

到目前为止,我发现获得这些信息的一种方法是使用带有选项keep.inbag=TrandomForest包,这允许您检索创建每棵树所使用的样本的信息,以及检索林中每棵树的树结构的方法getTree

我创建了一个函数来从getTree中检索给定树结构的终端节点id。

# function to retrieve the terminal node id given a rf tree structure and a sample (with numerical only features)
get_terminal_node_id_for_sample <- function(tree, sample){
node_id=1
search <- TRUE
while(search){
if(tree$status[node_id]=="-1"){
search <- FALSE
break
}
if(sample[as.character(tree$split.var[node_id])] < tree$split.point[node_id]){
node_id <- as.numeric(tree$left.daughter[node_id])
} else {
node_id <- as.numeric(tree$right.daughter[node_id])
}
}
return(node_id)
}

并像这样使用:

library(randomForest)
library(MASS)
library(e1071)
# load data
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob=c(0.8, 0.2))
train <- Boston[ind == 1,]
test <- Boston[ind == 2,]
# train random forest and keep inbag information
model = randomForest(medv~.,data = train,
keep.inbag=T)
# get the first tree of the forest
treeind <- 1
tree <- data.frame(getTree(model, k=treeind, labelVar=TRUE))
# loop over each sample in inbag of the first tree
for (sampleind in which(model$inbag[,treeind]>0)){
sample <- train[sampleind,]
node_id <- get_terminal_node_id_for_sample(tree,sample)

##########################
# do whatever with node_id
##########################

print(paste("sample",sampleind,"is in terminal node",node_id,sep=" "))
}

需要一提的是:我只测试了数字特征。

最新更新