Expand source code
import numpy as np

from collections import namedtuple

from sklearn.tree import BaseDecisionTree

TreeData = namedtuple('TreeData', "children_left children_right "
                                  "feature threshold n_node_samples impurity value n_classes n_outputs")


def extract_figs_tree(node, n_classes):
    """Flatten a FIGS tree into parallel per-node lists (sklearn ``Tree`` layout).

    Nodes are numbered in pre-order, visiting the right child before the left
    child, so the root is node 0. For every node ``i``, ``children_left[i]`` and
    ``children_right[i]`` hold the index of the corresponding child, or -1 when
    there is no such child (leaf).

    Args:
        node: Root of the FIGS tree, or None (yields empty per-node lists).
            Each node must expose ``left``, ``right``, ``value``,
            ``impurity_reduction``, ``feature``, ``threshold`` and ``idxs``.
            A node is treated as a split iff ``right`` is not None; the
            ``left`` subtree of a node whose ``right`` is None is not visited.
        n_classes: Number of classes; stored as ``np.array([n_classes])``.

    Returns:
        TreeData whose first seven fields are equal-length Python lists with
        one entry per node, and whose ``n_classes``/``n_outputs`` fields are
        one-element NumPy arrays. Leaves get ``feature = threshold = -2``;
        ``impurity`` holds each node's ``impurity_reduction`` (0 when None);
        ``n_node_samples`` holds ``np.sum(nd.idxs)`` (presumably a boolean
        sample mask — TODO confirm); ``value`` falls back to a (1, 1) zero
        array when ``nd.value`` is None.
    """
    tree_data = TreeData(
        children_left=[],
        children_right=[],
        feature=[],
        threshold=[],
        n_node_samples=[],
        impurity=[],
        value=[],
        n_classes=np.array([n_classes]),
        n_outputs=np.array([1]))

    def _add_node(nd):
        """Append ``nd`` and its subtree in pre-order; return ``nd``'s index."""
        # The index is the node's position in the per-node lists, so it is
        # always consistent with where the node is actually stored.
        node_id = len(tree_data.children_left)
        is_split = nd.right is not None
        value = np.expand_dims(np.array([0]), axis=-1) if nd.value is None else nd.value
        impurity_reduction = 0 if nd.impurity_reduction is None else nd.impurity_reduction

        # Reserve this node's slot first; child indices are only known after
        # each child subtree has been appended.
        tree_data.children_left.append(-1)
        tree_data.children_right.append(-1)
        tree_data.feature.append(nd.feature if is_split else -2)
        tree_data.threshold.append(nd.threshold if is_split else -2)
        tree_data.n_node_samples.append(np.sum(nd.idxs))
        tree_data.impurity.append(impurity_reduction)
        tree_data.value.append(np.array(value))

        if is_split:
            tree_data.children_right[node_id] = _add_node(nd.right)
            if nd.left is not None:
                tree_data.children_left[node_id] = _add_node(nd.left)
        return node_id

    if node is not None:
        _add_node(node)
    return tree_data


class LightTreeViz:
    """Lightweight holder for the per-node arrays of a flattened FIGS tree.

    Every field of the ``TreeData`` returned by
    ``extract_figs_tree(figs_tree, n_classes)`` becomes an attribute of the
    same name: ``children_left``, ``children_right``, ``feature``,
    ``threshold``, ``n_node_samples``, ``impurity``, ``value``, ``n_classes``
    and ``n_outputs``. These mirror the attribute names of scikit-learn's
    internal ``Tree`` object.
    """

    def __init__(self, figs_tree, n_classes):
        flattened = extract_figs_tree(figs_tree, n_classes)
        # Copy field-by-field in TreeData order; values are shared, not copied.
        for attr_name, attr_value in flattened._asdict().items():
            setattr(self, attr_name, attr_value)


class DecisionTreeViz(BaseDecisionTree):
    """Expose a FIGS tree through scikit-learn's ``BaseDecisionTree`` interface.

    The FIGS tree ``dt`` is converted with ``LightTreeViz`` and stored as
    ``tree_``, alongside ``criterion``. Those two attributes are what
    sklearn's tree exporters (e.g. ``sklearn.tree.plot_tree``) read. TODO:
    confirm against the sklearn version in use.

    ``BaseDecisionTree.__init__`` is not called; the object is never fitted
    and only carries an already-built tree.
    """

    def __init__(self, dt, criterion, n_classes):
        tree = LightTreeViz(dt, n_classes)
        # sklearn's get_params() (used by repr() and clone()) looks up every
        # __init__ parameter as an attribute of the same name, so all three
        # constructor arguments must be stored.
        self.dt = dt
        self.n_classes = n_classes
        self.tree_ = tree
        self.criterion = criterion

Functions

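def extract_figs_tree(node, n_classes)

Flattens the FIGS tree rooted at node into a TreeData of per-node lists. The fields use the same names as scikit-learn's internal Tree object. Nodes are appended in depth-first pre-order, visiting the right child before the left.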
Expand source code
def extract_figs_tree(node, n_classes):
    """Flatten a FIGS tree into parallel per-node lists (sklearn ``Tree`` layout).

    Nodes are numbered in pre-order, visiting the right child before the left
    child, so the root is node 0. For every node ``i``, ``children_left[i]`` and
    ``children_right[i]`` hold the index of the corresponding child, or -1 when
    there is no such child (leaf).

    Args:
        node: Root of the FIGS tree, or None (yields empty per-node lists).
            Each node must expose ``left``, ``right``, ``value``,
            ``impurity_reduction``, ``feature``, ``threshold`` and ``idxs``.
            A node is treated as a split iff ``right`` is not None; the
            ``left`` subtree of a node whose ``right`` is None is not visited.
        n_classes: Number of classes; stored as ``np.array([n_classes])``.

    Returns:
        TreeData whose first seven fields are equal-length Python lists with
        one entry per node, and whose ``n_classes``/``n_outputs`` fields are
        one-element NumPy arrays. Leaves get ``feature = threshold = -2``;
        ``impurity`` holds each node's ``impurity_reduction`` (0 when None);
        ``n_node_samples`` holds ``np.sum(nd.idxs)`` (presumably a boolean
        sample mask — TODO confirm); ``value`` falls back to a (1, 1) zero
        array when ``nd.value`` is None.
    """
    tree_data = TreeData(
        children_left=[],
        children_right=[],
        feature=[],
        threshold=[],
        n_node_samples=[],
        impurity=[],
        value=[],
        n_classes=np.array([n_classes]),
        n_outputs=np.array([1]))

    def _add_node(nd):
        """Append ``nd`` and its subtree in pre-order; return ``nd``'s index."""
        # The index is the node's position in the per-node lists, so it is
        # always consistent with where the node is actually stored.
        node_id = len(tree_data.children_left)
        is_split = nd.right is not None
        value = np.expand_dims(np.array([0]), axis=-1) if nd.value is None else nd.value
        impurity_reduction = 0 if nd.impurity_reduction is None else nd.impurity_reduction

        # Reserve this node's slot first; child indices are only known after
        # each child subtree has been appended.
        tree_data.children_left.append(-1)
        tree_data.children_right.append(-1)
        tree_data.feature.append(nd.feature if is_split else -2)
        tree_data.threshold.append(nd.threshold if is_split else -2)
        tree_data.n_node_samples.append(np.sum(nd.idxs))
        tree_data.impurity.append(impurity_reduction)
        tree_data.value.append(np.array(value))

        if is_split:
            tree_data.children_right[node_id] = _add_node(nd.right)
            if nd.left is not None:
                tree_data.children_left[node_id] = _add_node(nd.left)
        return node_id

    if node is not None:
        _add_node(node)
    return tree_data

Classes

class DecisionTreeViz (dt, criterion, n_classes)

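Wraps a FIGS tree in a scikit-learn BaseDecisionTree subclass. The tree is converted with LightTreeViz and stored as tree_, together with the given criterion.

Note: the inherited BaseDecisionTree docstring ("Base class for decision trees. Warning: This class should not be used directly. Use derived classes instead.") does not describe this class. DecisionTreeViz is itself a concrete subclass with its own constructor and is meant to be instantiated directly.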

Expand source code
class DecisionTreeViz(BaseDecisionTree):
    """Expose a FIGS tree through scikit-learn's ``BaseDecisionTree`` interface.

    The FIGS tree ``dt`` is converted with ``LightTreeViz`` and stored as
    ``tree_``, alongside ``criterion``. Those two attributes are what
    sklearn's tree exporters (e.g. ``sklearn.tree.plot_tree``) read. TODO:
    confirm against the sklearn version in use.

    ``BaseDecisionTree.__init__`` is not called; the object is never fitted
    and only carries an already-built tree.
    """

    def __init__(self, dt, criterion, n_classes):
        tree = LightTreeViz(dt, n_classes)
        # sklearn's get_params() (used by repr() and clone()) looks up every
        # __init__ parameter as an attribute of the same name, so all three
        # constructor arguments must be stored.
        self.dt = dt
        self.n_classes = n_classes
        self.tree_ = tree
        self.criterion = criterion

Ancestors

  • sklearn.tree._classes.BaseDecisionTree
  • sklearn.base.MultiOutputMixin
  • sklearn.base.BaseEstimator
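class LightTreeViz (figs_tree, n_classes)

Copies every field of extract_figs_tree(figs_tree, n_classes) onto the instance as an attribute of the same name. DecisionTreeViz stores an instance of this class as its tree_ attribute.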
Expand source code
class LightTreeViz:
    """Lightweight holder for the per-node arrays of a flattened FIGS tree.

    Every field of the ``TreeData`` returned by
    ``extract_figs_tree(figs_tree, n_classes)`` becomes an attribute of the
    same name: ``children_left``, ``children_right``, ``feature``,
    ``threshold``, ``n_node_samples``, ``impurity``, ``value``, ``n_classes``
    and ``n_outputs``. These mirror the attribute names of scikit-learn's
    internal ``Tree`` object.
    """

    def __init__(self, figs_tree, n_classes):
        flattened = extract_figs_tree(figs_tree, n_classes)
        # Copy field-by-field in TreeData order; values are shared, not copied.
        for attr_name, attr_value in flattened._asdict().items():
            setattr(self, attr_name, attr_value)
class TreeData (children_left, children_right, feature, threshold, n_node_samples, impurity, value, n_classes, n_outputs)

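TreeData(children_left, children_right, feature, threshold, n_node_samples, impurity, value, n_classes, n_outputs)

Named tuple returned by extract_figs_tree. Its field names match the per-node attributes of scikit-learn's internal Tree object. The first seven fields are lists with one entry per node; n_classes and n_outputs are one-element NumPy arrays.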

Ancestors

  • builtins.tuple

Instance variables

var children_left

Alias for field number 0

var children_right

Alias for field number 1

var feature

Alias for field number 2

var impurity

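Alias for field number 5. Note: extract_figs_tree fills this with each node's impurity_reduction (0 when unset), not with the node's impurity.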

var n_classes

Alias for field number 7

var n_node_samples

Alias for field number 4. extract_figs_tree fills this with np.sum(nd.idxs) for each node.

var n_outputs

Alias for field number 8

var threshold

Alias for field number 3

var value

Alias for field number 6