我在Spark中有一个RandomForestClassifierModel。使用.toDebugString()输出如下
Tree 0 (weight 1.0):
If (feature 0 in {1.0,2.0,3.0})
If (feature 3 in {2.0,3.0})
If (feature 8 <= 55.3)
.
.
Else (feature 0 not in {1.0,2.0,3.0})
.
.
Tree 1 (weight 1.0):
.
.
...etc
我想在模型中查看实际数据,例如
Tree 0 (weight 1.0):
If (feature 0 in {1.0,2.0,3.0}) 60%
If (feature 3 in {2.0,3.0}) 57%
If (feature 8 <= 55.3) 22%
.
.
Else (feature 0 not in {1.0,2.0,3.0}) 40%
.
.
Tree 1 (weight 1.0):
.
...etc
通过查看每个节点中标签的概率,我可以看到数据(数千条记录)最有可能在树中遵循哪些路径,这将是非常好的洞察力!
我在这里找到了一个很棒的答案:Spark MLib决策树:按特征标记的概率?
不幸的是,答案中的方法使用MLlib API,经过多次尝试,我未能使用DataFrame API复制它,DataFrame API具有不同的实现类Node和Split:(
昨天我发现有用的一种方法是,我可以使用spark.read.parquet()函数从模型/数据文件读取输出。这样,关于某个节点的所有信息都可以作为整个数据框架来检索。
`val modelPath = "some/path/to/your/model"
val dataPath = modelPath + "/data"
val nodeData: DataFrame = spark.read.parquet(dataPath)
nodeData.show(500,false)
nodeData.printSchema()`
那么你可以用信息重建树。