Module imodels.util.convert
Expand source code
from typing import Union, List
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.tree import _tree
def tree_to_rules(tree: Union[DecisionTreeClassifier, DecisionTreeRegressor],
                  feature_names: List[str]) -> List[str]:
    """
    Return a list of rules from a tree, one rule per leaf.

    Each rule is the conjunction (joined with ``' and '``) of the split
    conditions on the path from the root to a leaf, e.g.
    ``'x1 <= 0.5 and x2 > 1.5'``. Rules are listed in depth-first,
    left-child-first order, i.e. leaves from left to right.

    Parameters
    ----------
    tree : Decision Tree Classifier/Regressor
        A fitted tree; only its ``tree_`` attribute is read.
    feature_names : list of variable names
        Indexed by the feature indices stored in ``tree.tree_.feature``.

    Returns
    -------
    rules : list of rules.
        Never empty. If the tree is a single leaf, the only rule is
        ``'<name> == <name>'`` with ``<name> = feature_names[0]``, a
        tautology that selects every sample.
    """
    # NOTE(review): an earlier TODO asked whether trees fit with
    # max_features != None need special handling. sklearn documents
    # tree_.feature as indices into the input columns, and max_features only
    # limits the candidates examined at each split, so no index remapping
    # should be needed -- confirm when upgrading sklearn.
    tree_ = tree.tree_
    rules = []

    # Iterative depth-first traversal with an explicit stack, so very deep
    # trees cannot raise RecursionError. The right child is pushed before the
    # left one so the left subtree is fully processed first, preserving the
    # left-to-right leaf order of a recursive pre-order walk.
    stack = [(0, [])]
    while stack:
        node, conditions = stack.pop()
        feature = tree_.feature[node]
        if feature == _tree.TREE_UNDEFINED:
            # Leaf: emit the accumulated path conditions as one rule.
            if conditions:
                rules.append(' and '.join(conditions))
            else:
                # The root itself is a leaf: use "c0 == c0" as a rule that
                # selects every sample.
                rules.append(' == '.join([feature_names[0]] * 2))
            continue
        name = feature_names[feature]
        threshold = tree_.threshold[node]
        stack.append((tree_.children_right[node],
                      conditions + ['{} > {}'.format(name, threshold)]))
        stack.append((tree_.children_left[node],
                      conditions + ['{} <= {}'.format(name, threshold)]))
    return rules
Functions
def tree_to_rules(tree, feature_names)
-
Return a list of rules from a tree
Parameters
tree : Decision Tree Classifier/Regressor
feature_names : list of variable names
Returns
rules : list of rules.
Expand source code
def tree_to_rules(tree: Union[DecisionTreeClassifier, DecisionTreeRegressor],
                  feature_names: List[str]) -> List[str]:
    """
    Return a list of rules from a tree, one rule per leaf.

    Each rule is the conjunction (joined with ``' and '``) of the split
    conditions on the path from the root to a leaf, e.g.
    ``'x1 <= 0.5 and x2 > 1.5'``. Rules are listed in depth-first,
    left-child-first order, i.e. leaves from left to right.

    Parameters
    ----------
    tree : Decision Tree Classifier/Regressor
        A fitted tree; only its ``tree_`` attribute is read.
    feature_names : list of variable names
        Indexed by the feature indices stored in ``tree.tree_.feature``.

    Returns
    -------
    rules : list of rules.
        Never empty. If the tree is a single leaf, the only rule is
        ``'<name> == <name>'`` with ``<name> = feature_names[0]``, a
        tautology that selects every sample.
    """
    # NOTE(review): an earlier TODO asked whether trees fit with
    # max_features != None need special handling. sklearn documents
    # tree_.feature as indices into the input columns, and max_features only
    # limits the candidates examined at each split, so no index remapping
    # should be needed -- confirm when upgrading sklearn.
    tree_ = tree.tree_
    rules = []

    # Iterative depth-first traversal with an explicit stack, so very deep
    # trees cannot raise RecursionError. The right child is pushed before the
    # left one so the left subtree is fully processed first, preserving the
    # left-to-right leaf order of a recursive pre-order walk.
    stack = [(0, [])]
    while stack:
        node, conditions = stack.pop()
        feature = tree_.feature[node]
        if feature == _tree.TREE_UNDEFINED:
            # Leaf: emit the accumulated path conditions as one rule.
            if conditions:
                rules.append(' and '.join(conditions))
            else:
                # The root itself is a leaf: use "c0 == c0" as a rule that
                # selects every sample.
                rules.append(' == '.join([feature_names[0]] * 2))
            continue
        name = feature_names[feature]
        threshold = tree_.threshold[node]
        stack.append((tree_.children_right[node],
                      conditions + ['{} > {}'.format(name, threshold)]))
        stack.append((tree_.children_left[node],
                      conditions + ['{} <= {}'.format(name, threshold)]))
    return rules