r-找到每个节点partykit的路径/规则



是否可以找到每个节点的路径/规则?我想提取每个节点的规则,而不仅仅是终端节点的规则。

示例:

library(partykit)
X    <- MASS::mvrnorm(n, rep(0, p), diag(p))
y    <- as.numeric(drop(X %*% rep(1, p)) > 2)
data <- data.frame(y, X)
tree      <- rpart(y ~ .,
data = data,
control = rpart.control(cp = 0.005))
pfit      <- as.party(tree)

我可以使用partykit:::.list.rules.party(pfit),但这会返回终端节点的规则。我正在查找每个节点的规则。

设置参数i = ...以指定需要规则的所有节点ID。使用nodeids(),您可以提取所有节点ID(默认情况下(:

R> partykit:::.list.rules.party(pfit, i = nodeids(pfit))
    1
   "" 
    2 
"X3 < 0.650618460125409" 
    3 
"X3 < 0.650618460125409 & X2 < 1.62837615944647" 
    4 
"X3 < 0.650618460125409 & X2 < 1.62837615944647 & X4 < 1.38485264813313" 
...

最新更新