如何获取 Spark 决策树模型的节点信息



我想通过Spark MLlib的决策树获得有关生成的模型的每个节点的更多详细信息。我使用 API 可以获得的最接近的是print(model.toDebugString()),它返回类似以下内容(取自 PySpark 文档)

DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.0)
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0

例如,如何修改 MLlib 源代码以获得每个节点的杂质和深度?(如有必要,如何在 PySpark 中调用新的 Scala 函数?

我将尝试通过描述我如何使用 PySpark 2.4.3 来补充@mostOfMajority的答案。

根节点

给定一个经过训练的决策树模型,您可以通过以下方式获取其根节点:

def _get_root_node(tree: DecisionTreeClassificationModel):
return tree._call_java('rootNode')

杂质

我们可以通过从根节点沿着树向下走来获得杂质。它的预购横向可以像这样完成:

def get_impurities(tree: DecisionTreeClassificationModel) -> List[float]:
def recur(node):
if node.numDescendants() == 0:
return []
ni = node.impurity()
return (
recur(node.leftChild()) + [ni] + recur(node.rightChild())
)
return recur(_get_root_node(tree))

In [1]: print(tree.toDebugString)
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f90ba6dbb0fe) of depth 3 with 7 nodes
If (feature 0 <= 6.5)
If (feature 0 <= 3.5)
Predict: 1.0
Else (feature 0 > 3.5)
If (feature 0 <= 5.0)
Predict: 0.0
Else (feature 0 > 5.0)
Predict: 1.0
Else (feature 0 > 6.5)
Predict: 0.0

In [2]: cat.get_impurities(tree)
Out[2]: [0.4444444444444444, 0.5, 0.5]

不幸的是,我找不到任何直接在PySpark或Spark(Scala API)中访问节点的方法。但是有一种方法可以从根节点开始并遍历到不同的节点。

(我只是在这里提到了杂质,但对于深度,可以很容易地用subtreeDepth代替,impurity

假设决策树模型实例dt

PySpark

root = dt.call("topNode")
root.impurity() # gives the impurity of the root node

现在,如果我们看一下适用于root的方法:

dir(root)
[u'apply', u'deepCopy', u'emptyNode', u'equals', 'getClass', u'getNode', u'hashCode', u'id', 'impurity', u'impurity_$eq', u'indexToLevel', u'initializeLogIfNecessary', u'isLeaf', u'isLeaf_$eq', u'isLeftChild', u'isTraceEnabled', u'leftChildIndex', u'leftNode', u'leftNode_$eq', u'log', u'logDebug', u'logError', u'logInfo', u'logName', u'logTrace', u'logWarning', u'maxNodesInLevel', u'notify', u'notifyAll', u'numDescendants', u'org$apache$spark$internal$Logging$$log_', u'org$apache$spark$internal$Logging$$log__$eq', u'parentIndex', u'predict', u'predict_$eq', u'rightChildIndex', u'rightNode', u'rightNode_$eq', u'split', u'split_$eq', u'startIndexInLevel', u'stats', u'stats_$eq', u'subtreeDepth', u'subtreeIterator', u'subtreeToString', u'subtreeToString$default$1', u'toString', u'wait']

我们可以做到:

root.leftNode().get().impurity()

这可能会在树中更深入,例如:

root.leftNode().get().rightNode().get().impurity()

由于在应用leftNode()rightNode()之后,我们得到了一个option,应用get或getOrElseis necessary to get to the desired节点类型。

如果你想知道我是如何得到这些奇怪的方法的,我必须承认,我作弊了!,即我首先研究了Scala API:

火花

以下行与上述行完全相同,并且假设dt相同,则给出相同的结果:

val root = dt.topNode
root.impurity

我们可以做到:

root.leftNode.get.impurity

这可能会在树中更深入,例如:

root.leftNode.get.rightNode.get.impurity

最新更新