我想使用 sklearn.tree 获取节点进行预测的所有信息。
例如:
from sklearn.datasets import load_iris
nfrom sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
现在我们可以使用以下方法预测类:
clf.predict(iris.data[0, :])
如何获取进行预测的叶节点以及存储在叶中的信息?
我知道上面例子的树的图形表示如下:
http://scikit-learn.org/stable/modules/tree.html#tree-classification
所以我知道对应于输入 iris.data[0, :](第一个左子项)的节点具有以下统计信息:
- 错误 = 0
- 样本=50
- 值 = [50 0 0]
是否可以在不打印树的情况下自动获取输出节点和(上述)信息?从我目前的理解来看,关键是获取进行预测的叶节点的 ID,然后将相关统计数据包含在 clf.tree_.value[ID] 和 clf.tree_.n_samples[ID] 中。
谢谢
看看这个问题。它说明了如何获取叶子的ID。然后你可以使用这些clf.tree_.value
和clf.tree.n_samples
。