local_stumps.py
import numpy as np


class LocalDecisionStump:
    """
    An object that implements a callable local decision stump function and
    that also includes some metadata that allows for an API to interact with
    other methods in the package.

    A local decision stump is a tri-valued function that is zero outside of a
    rectangular region, and on that region, takes either a positive or negative
    value, depending on whether a single designated feature is above or below a
    threshold. For more information on what a local decision stump is, refer to
    our paper.

    Parameters
    ----------
    feature: int
        Feature used in the decision stump
    threshold: float
        Threshold used in the decision stump
    left_val: float
        The value taken when x_k <= threshold
    right_val: float
        The value taken when x_k > threshold
    a_features: list of ints
        List of ancestor feature indices (ordered from highest ancestor to
        lowest)
    a_thresholds: list of floats
        List of ancestor thresholds (ordered from highest ancestor to lowest)
    a_signs: list of bools
        List of signs indicating whether the current node is in the left child
        (False) or right child (True) of the ancestor nodes (ordered from
        highest ancestor to lowest)
    """

    def __init__(self, feature, threshold, left_val, right_val, a_features,
                 a_thresholds, a_signs):
        self.feature = feature
        self.threshold = threshold
        self.left_val = left_val
        self.right_val = right_val
        self.a_features = a_features
        self.a_thresholds = a_thresholds
        self.a_signs = a_signs

    def __call__(self, data):
        """
        Return values of the local decision stump function on an input data
        matrix with samples as rows.

        Parameters
        ----------
        data: array-like of shape (n_samples, n_features)
            Data matrix to feed into the local decision stump function

        Returns
        -------
        values: array-like of shape (n_samples,)
            Function values on the data
        """
        root_to_stump_path_indicators = \
            _compare_all(data, self.a_features, np.array(self.a_thresholds),
                         np.array(self.a_signs))
        in_node = np.all(root_to_stump_path_indicators, axis=1).astype(int)
        is_right = _compare(data, self.feature, self.threshold).astype(int)
        values = in_node * (is_right * self.right_val +
                            (1 - is_right) * self.left_val)
        return values

    def __repr__(self):
        return f"LocalDecisionStump(feature={self.feature}, " \
               f"threshold={self.threshold}, left_val={self.left_val}, " \
               f"right_val={self.right_val}, a_features={self.a_features}, " \
               f"a_thresholds={self.a_thresholds}, " \
               f"a_signs={self.a_signs})"

    def get_depth(self):
        """
        Get the depth of the local decision stump, i.e. count how many ancestor
        nodes it has. The root node has depth 0.
        """
        return len(self.a_features)
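
# Illustrative sketch (the feature, threshold, and values here are
# hypothetical): a root-level stump with no ancestors that splits on feature 0
# at threshold 0.5, taking value -1 on the left and +1 on the right.
#
#     stump = LocalDecisionStump(feature=0, threshold=0.5, left_val=-1.0,
#                                right_val=1.0, a_features=[], a_thresholds=[],
#                                a_signs=[])
#     stump(np.array([[0.2], [0.8]]))   # -> array([-1.,  1.])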

def make_stump(node_no, tree_struct, parent_stump, is_right_child,
               normalize=False):
    """
    Create a single local decision stump corresponding to a node in a
    scikit-learn tree structure object. The nonzero values of the stump are
    chosen so that the vector of local decision stump values over the training
    set (used to fit the tree) is orthogonal to those of all ancestor nodes.

    Parameters
    ----------
    node_no: int
        The index of the node
    tree_struct: object
        The scikit-learn tree object
    parent_stump: LocalDecisionStump object
        The local decision stump corresponding to the parent of the node in
        question
    is_right_child: bool
        True if the new node is the right child of the parent node, False
        otherwise
    normalize: bool
        Flag. If set to True, then divide the nonzero function values by
        sqrt(n_samples in node) so that the vector of function values on the
        training set has unit norm. If False, then do not divide, so that the
        vector of function values on the training set has norm equal to
        n_samples in node.

    Returns
    -------
    stump: LocalDecisionStump object
        The local decision stump corresponding to the node in question
    """
    # Get features, thresholds and signs for ancestors
    if parent_stump is None:  # If root node
        a_features = []
        a_thresholds = []
        a_signs = []
    else:
        a_features = parent_stump.a_features + [parent_stump.feature]
        a_thresholds = parent_stump.a_thresholds + [parent_stump.threshold]
        a_signs = parent_stump.a_signs + [is_right_child]

    # Get indices for left and right children of the node in question
    left_child = tree_struct.children_left[node_no]
    right_child = tree_struct.children_right[node_no]

    # Get quantities relevant to the node in question
    feature = tree_struct.feature[node_no]
    threshold = tree_struct.threshold[node_no]
    left_size = tree_struct.weighted_n_node_samples[left_child]
    right_size = tree_struct.weighted_n_node_samples[right_child]
    parent_size = tree_struct.weighted_n_node_samples[node_no]
    normalization = parent_size if normalize else 1
    left_val = - np.sqrt(right_size / (left_size * normalization))
    right_val = np.sqrt(left_size / (right_size * normalization))

    return LocalDecisionStump(feature, threshold, left_val, right_val,
                              a_features, a_thresholds, a_signs)
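
# Why these values (a short derivation consistent with the code above): write
# n_l = left_size and n_r = right_size, so that
#     left_val  = -sqrt(n_r / (n_l * c)),    right_val = sqrt(n_l / (n_r * c)),
# where c = 1 (normalize=False) or c = n_l + n_r (normalize=True). Over the
# training samples falling in this node,
#     n_l * left_val + n_r * right_val
#         = (-sqrt(n_l * n_r) + sqrt(n_l * n_r)) / sqrt(c) = 0,
# so the stump averages to zero on its node and is therefore orthogonal to any
# ancestor stump (which is constant on this node). With normalize=True, the
# squared norm is n_l * left_val**2 + n_r * right_val**2 = (n_r + n_l) / c = 1.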

def make_stumps(tree_struct, normalize=False):
    """
    Create a collection of local decision stumps corresponding to all internal
    nodes in a scikit-learn tree structure object.

    Parameters
    ----------
    tree_struct: object
        The scikit-learn tree object
    normalize: bool
        Flag. If set to True, then divide the nonzero function values by
        sqrt(n_samples in node) so that the vector of function values on the
        training set has unit norm. If False, then do not divide, so that the
        vector of function values on the training set has norm equal to
        n_samples in node.

    Returns
    -------
    stumps: list of LocalDecisionStump objects
        The local decision stumps corresponding to all internal nodes in the
        tree structure
    """
    stumps = []

    def make_stump_iter(node_no, tree_struct, parent_stump, is_right_child,
                        normalize, stumps):
        """
        Helper function for recursively making local decision stump objects and
        appending them to the list stumps.
        """
        new_stump = make_stump(node_no, tree_struct, parent_stump,
                               is_right_child, normalize)
        stumps.append(new_stump)
        left_child = tree_struct.children_left[node_no]
        right_child = tree_struct.children_right[node_no]
        if tree_struct.feature[left_child] != -2:  # is not a leaf
            make_stump_iter(left_child, tree_struct, new_stump, False,
                            normalize, stumps)
        if tree_struct.feature[right_child] != -2:  # is not a leaf
            make_stump_iter(right_child, tree_struct, new_stump, True,
                            normalize, stumps)

    make_stump_iter(0, tree_struct, None, None, normalize, stumps)
    return stumps

def tree_feature_transform(stumps, X):
    """
    Transform the data matrix X using a mapping derived from a collection of
    local decision stump functions.

    If the list of stumps is empty, return an array of shape (n_samples, 0).

    Parameters
    ----------
    stumps: list of LocalDecisionStump objects
        List of stump functions to use to transform data
    X: array-like of shape (n_samples, n_features)
        Original data matrix

    Returns
    -------
    X_transformed: array-like of shape (n_samples, n_stumps)
        Transformed data matrix
    """
    transformed_feature_vectors = [np.empty((X.shape[0], 0))]
    for stump in stumps:
        transformed_feature_vec = stump(X)[:, np.newaxis]
        transformed_feature_vectors.append(transformed_feature_vec)
    X_transformed = np.hstack(transformed_feature_vectors)
    return X_transformed

def _compare(data, k, threshold, sign=True):
    """
    Obtain indicator vector for the samples with k-th feature > threshold
    (or <= threshold if sign is False).
    """
    if sign:
        return data[:, k] > threshold
    else:
        return data[:, k] <= threshold


def _compare_all(data, ks, thresholds, signs):
    """
    Obtain an indicator matrix with one column per feature index k in ks,
    each entry True where the k-th feature is > its threshold (if the
    corresponding sign is True) or <= its threshold (if the sign is False).
    """
    return ~np.logical_xor(data[:, ks] > thresholds, signs)
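
# Usage sketch, assuming scikit-learn is available (the data and model below
# are synthetic and chosen only for illustration): fit a decision tree, build
# its local decision stumps, and transform the training data. With
# normalize=True, the columns of the transformed matrix should be
# (numerically) orthonormal over the training set.
if __name__ == "__main__":
    from sklearn.tree import DecisionTreeRegressor

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 5))
    y = X[:, 0] * (X[:, 1] > 0) + rng.normal(scale=0.1, size=200)

    tree = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)
    stumps = make_stumps(tree.tree_, normalize=True)
    psi = tree_feature_transform(stumps, X)

    print(f"number of stumps (internal nodes): {len(stumps)}")
    print(f"transformed data shape: {psi.shape}")
    gram = psi.T @ psi
    print("max deviation from identity:",
          np.abs(gram - np.eye(len(stumps))).max())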