我想通过Spark MLlib的决策树获得有关生成的模型的每个节点的更多详细信息。我使用 API 可以获得的最接近的是print(model.toDebugString())
,它返回类似以下内容(取自 PySpark 文档)
DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.0)
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
例如,如何修改 MLlib 源代码以获得每个节点的杂质和深度?(如有必要,如何在 PySpark 中调用新的 Scala 函数?
我将尝试通过描述我如何使用 PySpark 2.4.3 来补充@mostOfMajority的答案。
根节点
给定一个经过训练的决策树模型,您可以通过以下方式获取其根节点:
def _get_root_node(tree: DecisionTreeClassificationModel):
return tree._call_java('rootNode')
杂质
我们可以通过从根节点沿着树向下走来获得杂质。它的预购横向可以像这样完成:
def get_impurities(tree: DecisionTreeClassificationModel) -> List[float]:
def recur(node):
if node.numDescendants() == 0:
return []
ni = node.impurity()
return (
recur(node.leftChild()) + [ni] + recur(node.rightChild())
)
return recur(_get_root_node(tree))
例
In [1]: print(tree.toDebugString)
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f90ba6dbb0fe) of depth 3 with 7 nodes
If (feature 0 <= 6.5)
If (feature 0 <= 3.5)
Predict: 1.0
Else (feature 0 > 3.5)
If (feature 0 <= 5.0)
Predict: 0.0
Else (feature 0 > 5.0)
Predict: 1.0
Else (feature 0 > 6.5)
Predict: 0.0
In [2]: cat.get_impurities(tree)
Out[2]: [0.4444444444444444, 0.5, 0.5]
不幸的是,我找不到任何直接在PySpark或Spark(Scala API)中访问节点的方法。但是有一种方法可以从根节点开始并遍历到不同的节点。
(我只是在这里提到了杂质,但对于深度,可以很容易地用subtreeDepth
代替,impurity
。
假设决策树模型实例dt
:
PySpark
root = dt.call("topNode")
root.impurity() # gives the impurity of the root node
现在,如果我们看一下适用于root
的方法:
dir(root)
[u'apply', u'deepCopy', u'emptyNode', u'equals', 'getClass', u'getNode', u'hashCode', u'id', 'impurity', u'impurity_$eq', u'indexToLevel', u'initializeLogIfNecessary', u'isLeaf', u'isLeaf_$eq', u'isLeftChild', u'isTraceEnabled', u'leftChildIndex', u'leftNode', u'leftNode_$eq', u'log', u'logDebug', u'logError', u'logInfo', u'logName', u'logTrace', u'logWarning', u'maxNodesInLevel', u'notify', u'notifyAll', u'numDescendants', u'org$apache$spark$internal$Logging$$log_', u'org$apache$spark$internal$Logging$$log__$eq', u'parentIndex', u'predict', u'predict_$eq', u'rightChildIndex', u'rightNode', u'rightNode_$eq', u'split', u'split_$eq', u'startIndexInLevel', u'stats', u'stats_$eq', u'subtreeDepth', u'subtreeIterator', u'subtreeToString', u'subtreeToString$default$1', u'toString', u'wait']
我们可以做到:
root.leftNode().get().impurity()
这可能会在树中更深入,例如:
root.leftNode().get().rightNode().get().impurity()
由于在应用leftNode()
或rightNode()
之后,我们得到了一个option
,应用get
或getOrElseis necessary to get to the desired
节点类型。
如果你想知道我是如何得到这些奇怪的方法的,我必须承认,我作弊了!,即我首先研究了Scala API:
火花
以下行与上述行完全相同,并且假设dt
相同,则给出相同的结果:
val root = dt.topNode
root.impurity
我们可以做到:
root.leftNode.get.impurity
这可能会在树中更深入,例如:
root.leftNode.get.rightNode.get.impurity