如何在Spark (scala)中查看随机森林统计信息



我在Spark中有一个RandomForestClassifierModel。使用.toDebugString()输出如下

Tree 0 (weight 1.0):
  If (feature 0 in {1.0,2.0,3.0})
   If (feature 3 in {2.0,3.0})
    If (feature 8 <= 55.3) 
.
.
  Else (feature 0 not in {1.0,2.0,3.0})
.
.
Tree 1 (weight 1.0):
.
.
...etc

我想在模型中查看实际数据,例如

Tree 0 (weight 1.0):
  If (feature 0 in {1.0,2.0,3.0}) 60%
   If (feature 3 in {2.0,3.0}) 57%
    If (feature 8 <= 55.3) 22%
.
.
  Else (feature 0 not in {1.0,2.0,3.0}) 40%
.
.
Tree 1 (weight 1.0):
.
...etc

通过查看每个节点中标签的概率,我可以看到数据(数千条记录)最有可能在树中遵循哪些路径,这将是非常好的洞察力!

我在这里找到了一个很棒的答案:Spark MLib决策树:按特征标记的概率?

不幸的是,答案中的方法使用MLlib API,经过多次尝试,我未能使用DataFrame API复制它,DataFrame API具有不同的实现类Node和Split:(

昨天我发现有用的一种方法是,我可以使用spark.read.parquet()函数从模型/数据文件读取输出。这样,关于某个节点的所有信息都可以作为整个数据框架来检索。

`val modelPath = "some/path/to/your/model"
val dataPath = modelPath + "/data"    
val nodeData: DataFrame = spark.read.parquet(dataPath)
nodeData.show(500,false)
nodeData.printSchema()`

那么你可以用信息重建树。

相关内容

  • 没有找到相关文章