Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

convert.py 4.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
  1. import numpy as np
  2. from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
  3. from sklearn.tree import _tree
  4. from typing import Union, List, Tuple
  5. def tree_to_rules(tree: Union[DecisionTreeClassifier, DecisionTreeRegressor],
  6. feature_names: List[str],
  7. prediction_values: bool = False, round_thresholds=True) -> List[str]:
  8. """
  9. Return a list of rules from a tree
  10. Parameters
  11. ----------
  12. tree : Decision Tree Classifier/Regressor
  13. feature_names: list of variable names
  14. Returns
  15. -------
  16. rules : list of rules.
  17. """
  18. # XXX todo: check the case where tree is build on subset of features,
  19. # ie max_features != None
  20. tree_ = tree.tree_
  21. feature_name = [
  22. feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
  23. for i in tree_.feature
  24. ]
  25. rules = []
  26. def recurse(node, base_name):
  27. if tree_.feature[node] != _tree.TREE_UNDEFINED:
  28. name = feature_name[node]
  29. symbol = '<='
  30. symbol2 = '>'
  31. threshold = tree_.threshold[node]
  32. if round_thresholds:
  33. threshold = np.round(threshold, decimals=5)
  34. text = base_name + ["{} {} {}".format(name, symbol, threshold)]
  35. recurse(tree_.children_left[node], text)
  36. text = base_name + ["{} {} {}".format(name, symbol2,
  37. threshold)]
  38. recurse(tree_.children_right[node], text)
  39. else:
  40. rule = str.join(' and ', base_name)
  41. rule = (rule if rule != ''
  42. else ' == '.join([feature_names[0]] * 2))
  43. # a rule selecting all is set to "c0==c0"
  44. if prediction_values:
  45. rules.append((rule, tree_.value[node][0].tolist()))
  46. else:
  47. rules.append(rule)
  48. recurse(0, [])
  49. return rules if len(rules) > 0 else 'True'
  50. def tree_to_code(clf, feature_names):
  51. '''Prints a tree with a single split
  52. '''
  53. n_nodes = clf.tree_.node_count
  54. children_left = clf.tree_.children_left
  55. children_right = clf.tree_.children_right
  56. feature = clf.tree_.feature
  57. threshold = clf.tree_.threshold
  58. node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
  59. is_leaves = np.zeros(shape=n_nodes, dtype=bool)
  60. stack = [(0, 0)] # start with the root node id (0) and its depth (0)
  61. s = ''
  62. while len(stack) > 0:
  63. # `pop` ensures each node is only visited once
  64. node_id, depth = stack.pop()
  65. node_depth[node_id] = depth
  66. # If the left and right child of a node is not the same we have a split
  67. # node
  68. is_split_node = children_left[node_id] != children_right[node_id]
  69. # If a split node, append left and right children and depth to `stack`
  70. # so we can loop through them
  71. if is_split_node:
  72. stack.append((children_left[node_id], depth + 1))
  73. stack.append((children_right[node_id], depth + 1))
  74. else:
  75. is_leaves[node_id] = True
  76. # print("The binary tree structure has {n} nodes and has "
  77. # "the following tree structure:\n".format(n=n_nodes))
  78. for i in range(n_nodes):
  79. if is_leaves[i]:
  80. pass
  81. # print("{space}node={node} is a leaf node.".format(
  82. # space=node_depth[i] * "\t", node=i))
  83. else:
  84. s += f"{feature_names[feature[i]]} <= {threshold[i]}"
  85. return f"\033[96m{s}\033[00m\n"
  86. def itemsets_to_rules(itemsets: List[Tuple]) -> List[str]:
  87. itemsets_clean = list(filter(lambda it: it != 'null' and 'All' not in ''.join(it), itemsets))
  88. f = lambda itemset: ' and '.join([single_discretized_feature_to_rule(item) for item in itemset])
  89. return list(map(f, itemsets_clean))
  90. def dict_to_rule(rule, clf_feature_dict):
  91. """
  92. Function to accept rule dict and convert to Rule object
  93. Parameters:
  94. rule: list of dict of schema
  95. [
  96. {
  97. 'feature': int,
  98. 'operator': str,
  99. 'value': float
  100. },
  101. ]
  102. """
  103. output = ''
  104. for condition in rule:
  105. output += '{} {} {} and '.format(
  106. clf_feature_dict[int(condition['feature'])],
  107. condition['operator'],
  108. condition['pivot']
  109. )
  110. return output[:-5]
  111. def single_discretized_feature_to_rule(feat: str) -> str:
  112. # categorical feature
  113. if '_to_' not in feat:
  114. return f'{feat} > 0.5'
  115. # discretized numeric feature
  116. feat_split = feat.split('_to_')
  117. upper_value = feat_split[-1]
  118. lower_value = feat_split[-2].split('_')[-1]
  119. lower_to_upper_len = 1 + len(lower_value) + 4 + len(upper_value)
  120. feature_name = feat[:-lower_to_upper_len]
  121. if lower_value == '-inf':
  122. rule = f'{feature_name} <= {upper_value}'
  123. elif upper_value == 'inf':
  124. rule = f'{feature_name} > {lower_value}'
  125. else:
  126. rule = f'{feature_name} > {lower_value} and {feature_name} <= {upper_value}'
  127. return rule
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...