Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

viz_utils.py 5.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
  1. import numpy as np
  2. import pandas as pd
  3. from collections import namedtuple
  4. from sklearn import __version__
  5. from sklearn.base import ClassifierMixin, RegressorMixin
  6. from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
  7. from sklearn.tree._tree import Tree
  8. TreeData = namedtuple('TreeData', 'left_child right_child feature threshold impurity n_node_samples weighted_n_node_samples missing_go_to_left')
  9. def _extract_arrays_from_figs_tree(figs_tree):
  10. """Takes in a FIGS tree and recursively converts it to arrays that we can later use to build a sklearn decision tree object
  11. """
  12. tree_data = TreeData(
  13. left_child=[],
  14. right_child=[],
  15. feature=[],
  16. threshold=[],
  17. impurity=[],
  18. n_node_samples=[],
  19. weighted_n_node_samples=[],
  20. missing_go_to_left=[],
  21. )
  22. value_sklearn_array = []
  23. def _update_node(node):
  24. if node is None:
  25. return
  26. node_id_left = node_id_right = -1
  27. feature = threshold = -2
  28. value_sklearn = node.value_sklearn
  29. has_children = node.left is not None
  30. if has_children:
  31. node_id_left = node.left.node_id
  32. node_id_right = node.right.node_id
  33. feature = node.feature
  34. threshold = node.threshold
  35. tree_data.left_child.append(node_id_left)
  36. tree_data.right_child.append(node_id_right)
  37. tree_data.feature.append(feature)
  38. tree_data.threshold.append(threshold)
  39. tree_data.impurity.append(node.impurity)
  40. tree_data.n_node_samples.append(np.sum(value_sklearn))
  41. tree_data.weighted_n_node_samples.append(np.sum(value_sklearn)) # TODO add sample weights
  42. tree_data.missing_go_to_left.append(1)
  43. value_sklearn_array.append(value_sklearn)
  44. if has_children:
  45. _update_node(node.left)
  46. _update_node(node.right)
  47. _update_node(figs_tree)
  48. return tree_data, np.array(value_sklearn_array)
  49. def extract_sklearn_tree_from_figs(figs, tree_num, n_classes, with_leaf_predictions=False):
  50. """Takes in a FIGS model and convert tree tree_num to a sklearn decision tree
  51. """
  52. try:
  53. figs_tree = figs.trees_[tree_num]
  54. except:
  55. raise AttributeError(f'Can not load tree_num = {tree_num}!')
  56. tree_data_namedtuple, value_sklearn_array = _extract_arrays_from_figs_tree(figs_tree)
  57. # manipulate tree_data_namedtuple into the numpy array of tuples that sklearn expects for use with __setstate__()
  58. df_tree_data = pd.DataFrame(tree_data_namedtuple._asdict())
  59. tree_data_list_of_tuples = list(df_tree_data.itertuples(index=False, name=None))
  60. _dtypes = np.dtype([('left_child', 'i8'), ('right_child', 'i8'), ('feature', 'i8'), ('threshold', 'f8'), ('impurity', 'f8'), ('n_node_samples', 'i8'), ('weighted_n_node_samples', 'f8'), ('missing_go_to_left', 'u1')])
  61. tree_data_array = np.array(tree_data_list_of_tuples, dtype=_dtypes)
  62. # reshape value_sklearn_array to match the expected shape of (n_nodes,1,2) for values
  63. value_sklearns = value_sklearn_array.reshape(value_sklearn_array.shape[0], 1, value_sklearn_array.shape[1])
  64. if n_classes == 1:
  65. value_sklearns = np.ascontiguousarray(value_sklearns[:, :, 0:1])
  66. # get the max_depth
  67. def get_max_depth(node):
  68. if node is None:
  69. return -1
  70. else:
  71. return 1 + max(get_max_depth(node.left), get_max_depth(node.right))
  72. max_depth = get_max_depth(figs_tree)
  73. # get other variables needed for the sklearn.tree._tree.Tree constructor and __setstate__() calls
  74. # n_samples = np.sum(figs_tree.value_sklearn)
  75. node_count = len(tree_data_array)
  76. features = np.array(tree_data_namedtuple.feature)
  77. n_features = np.unique(features[np.where( 0 <= features )]).size
  78. n_classes_array = np.array([n_classes], dtype=int)
  79. n_outputs = 1
  80. # make dict to pass to __setstate__()
  81. _state = {'max_depth': max_depth,
  82. 'node_count': node_count,
  83. 'nodes': tree_data_array,
  84. 'values': value_sklearns,
  85. 'n_features_in_': figs.n_features_in_,
  86. # WARNING this circumvents
  87. # UserWarning: Trying to unpickle estimator DecisionTreeClassifier from version pre-0.18 when using version
  88. # https://github.com/scikit-learn/scikit-learn/blob/53acd0fe52cb5d8c6f5a86a1fc1352809240b68d/sklearn/base.py#L279
  89. '_sklearn_version': __version__,
  90. }
  91. tree = Tree(n_features=n_features, n_classes=n_classes_array, n_outputs=n_outputs)
  92. # https://github.com/scikit-learn/scikit-learn/blob/3850935ea610b5231720fdf865c837aeff79ab1b/sklearn/tree/_tree.pyx#L677
  93. tree.__setstate__(_state)
  94. # add the tree_ for the dt __setstate__()
  95. # note the trailing underscore also trips the sklearn_is_fitted protections
  96. _state['tree_'] = tree
  97. _state['classes_'] = np.arange(n_classes)
  98. _state['n_outputs_'] = n_outputs
  99. # construct sklearn object and __setstate__()
  100. if isinstance(figs, ClassifierMixin):
  101. dt = DecisionTreeClassifier(max_depth=max_depth)
  102. elif isinstance(figs, RegressorMixin):
  103. dt = DecisionTreeRegressor(max_depth=max_depth)
  104. try:
  105. dt.__setstate__(_state);
  106. except:
  107. raise Exception(f'Did not successfully run __setstate__() when translating to {type(dt)}, did sklearn update?')
  108. if not with_leaf_predictions:
  109. return dt
  110. else:
  111. leaf_values_dict = {}
  112. def _read_node(node):
  113. if node is None:
  114. return None
  115. elif node.left is None and node.right is None:
  116. leaf_values_dict[node.node_id] = node.value[0][0]
  117. _read_node(node.left)
  118. _read_node(node.right)
  119. _read_node(figs_tree)
  120. return dt, leaf_values_dict
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...