hierarchical_shrinkage.py

import time
from copy import deepcopy
from typing import List

import numpy as np
from sklearn import datasets
from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin
from sklearn.metrics import r2_score, mean_squared_error, log_loss
from sklearn.model_selection import cross_val_score, KFold
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier, export_text
from sklearn.ensemble import (
    GradientBoostingClassifier,
    GradientBoostingRegressor,
    RandomForestRegressor,
)

from imodels.util import checks
from imodels.util.arguments import check_fit_arguments
from imodels.util.tree import compute_tree_complexity


class HSTree(BaseEstimator):
    def __init__(
        self,
        estimator_: BaseEstimator = DecisionTreeClassifier(max_leaf_nodes=20),
        reg_param: float = 1,
        shrinkage_scheme_: str = "node_based",
        max_leaf_nodes: int = None,
        random_state: int = None,
    ):
        """HSTree (tree with hierarchical shrinkage applied).

        Hierarchical shrinkage is an extremely fast post-hoc regularization method
        which works on any decision tree (or tree-based ensemble, such as a random
        forest). It does not modify the tree structure; instead, it regularizes the
        tree by shrinking the prediction at each node towards the sample means of
        its ancestors (using a single regularization parameter). Experiments over a
        wide variety of datasets show that hierarchical shrinkage substantially
        increases the predictive performance of individual decision trees and
        decision-tree ensembles.
        https://arxiv.org/abs/2202.00858

        Params
        ------
        estimator_: sklearn tree or tree ensemble model (e.g. RandomForest or GradientBoosting)
            Defaults to a CART classification tree with 20 max leaf nodes.
            Note: this estimator will be directly modified.
        reg_param: float
            Higher is more regularization (can be arbitrarily large, should not be < 0).
        shrinkage_scheme_: str
            Experimental: used to experiment with different forms of shrinkage. Options are:
            (i) "node_based" shrinks based on the number of samples in the parent node,
            (ii) "leaf_based" only shrinks leaf nodes, based on the number of leaf samples,
            (iii) "constant" shrinks every node by a constant lambda.
        max_leaf_nodes: int
            If estimator_ is None, then max_leaf_nodes is passed to the default decision tree.
        """
        super().__init__()
        self.reg_param = reg_param
        self.estimator_ = estimator_
        self.shrinkage_scheme_ = shrinkage_scheme_
        self.random_state = random_state
        # if a fitted estimator was passed, shrink it immediately
        if checks.check_is_fitted(self.estimator_):
            self._shrink()
        if max_leaf_nodes is not None:
            self.estimator_.max_leaf_nodes = max_leaf_nodes
        self.estimator_.random_state = random_state
    def get_params(self, deep=True):
        d = {
            "reg_param": self.reg_param,
            "estimator_": self.estimator_,
            "shrinkage_scheme_": self.shrinkage_scheme_,
            "max_leaf_nodes": self.estimator_.max_leaf_nodes,
        }
        if deep:
            return deepcopy(d)
        return d

    def fit(self, X, y, sample_weight=None, *args, **kwargs):
        # remove feature_names if it exists (note: only works as a keyword-arg);
        # None is returned if it was not passed
        feature_names = kwargs.pop("feature_names", None)
        X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
        if feature_names is not None:
            self.feature_names = feature_names
        self.estimator_ = self.estimator_.fit(
            X, y, *args, sample_weight=sample_weight, **kwargs
        )
        self._shrink()

        # compute complexity
        if hasattr(self.estimator_, "tree_"):
            self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
        elif hasattr(self.estimator_, "estimators_"):
            self.complexity_ = 0
            for i in range(len(self.estimator_.estimators_)):
                t = deepcopy(self.estimator_.estimators_[i])
                if isinstance(t, np.ndarray):
                    assert t.size == 1, "multiple trees stored under tree_?"
                    t = t[0]
                self.complexity_ += compute_tree_complexity(t.tree_)
        return self
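
    # For the default "node_based" scheme, the recursion below implements the
    # telescoping estimate from the paper (https://arxiv.org/abs/2202.00858):
    # walking from the root t_0 down to node t_L,
    #
    #   value(t_L) = mean(t_0) + sum_{l=1..L} (mean(t_l) - mean(t_{l-1})) / (1 + reg_param / N(t_{l-1}))
    #
    # where mean(t) is the node's sample mean and N(t) its number of samples.
    # A child's deviation from its parent is damped more strongly when the
    # parent holds few samples; cum_sum carries the damped sum down each path.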
    def _shrink_tree(
        self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0
    ):
        """Shrink the tree rooted at node i (recursively)."""
        if reg_param is None:
            reg_param = 1.0
        left = tree.children_left[i]
        right = tree.children_right[i]
        is_leaf = left == right
        n_samples = tree.weighted_n_node_samples[i]
        if isinstance(self, RegressorMixin) or isinstance(
            self.estimator_, GradientBoostingClassifier
        ):
            val = deepcopy(tree.value[i, :, :])
        else:  # if classification, normalize to a probability vector
            val = tree.value[i, :, :] / n_samples

        # Step 1: update cum_sum
        if parent_val is None and parent_num is None:
            # root node
            cum_sum = val
        else:
            # non-root node: add the (shrunk) deviation from the parent
            if self.shrinkage_scheme_ == "node_based":
                val_new = (val - parent_val) / (1 + reg_param / parent_num)
            elif self.shrinkage_scheme_ == "constant":
                val_new = (val - parent_val) / (1 + reg_param)
            else:  # leaf_based
                val_new = 0
            cum_sum += val_new

        # Step 2: update node values
        if (
            self.shrinkage_scheme_ == "node_based"
            or self.shrinkage_scheme_ == "constant"
        ):
            tree.value[i, :, :] = cum_sum
        else:  # leaf_based
            if is_leaf:  # only leaves are shrunk under the leaf_based scheme
                root_val = tree.value[0, :, :]
                tree.value[i, :, :] = root_val + (val - root_val) / (
                    1 + reg_param / n_samples
                )
            else:
                tree.value[i, :, :] = val

        # Step 3: recurse if not a leaf
        if not is_leaf:
            self._shrink_tree(
                tree,
                reg_param,
                left,
                parent_val=val,
                parent_num=n_samples,
                cum_sum=deepcopy(cum_sum),
            )
            self._shrink_tree(
                tree,
                reg_param,
                right,
                parent_val=val,
                parent_num=n_samples,
                cum_sum=deepcopy(cum_sum),
            )
        # the non-leaf nodes are edited too, for later visualization
        # (doesn't affect predictions)
        return tree

    def _shrink(self):
        if hasattr(self.estimator_, "tree_"):
            self._shrink_tree(self.estimator_.tree_, self.reg_param)
        elif hasattr(self.estimator_, "estimators_"):
            for t in self.estimator_.estimators_:
                if isinstance(t, np.ndarray):
                    assert t.size == 1, "multiple trees stored under tree_?"
                    t = t[0]
                self._shrink_tree(t.tree_, self.reg_param)
    def predict(self, X, *args, **kwargs):
        return self.estimator_.predict(X, *args, **kwargs)

    def predict_proba(self, X, *args, **kwargs):
        if hasattr(self.estimator_, "predict_proba"):
            return self.estimator_.predict_proba(X, *args, **kwargs)
        else:
            return NotImplemented

    def score(self, X, y, *args, **kwargs):
        if hasattr(self.estimator_, "score"):
            return self.estimator_.score(X, y, *args, **kwargs)
        else:
            return NotImplemented

    def __str__(self):
        # check if fitted
        if not checks.check_is_fitted(self.estimator_):
            s = self.__class__.__name__
            s += "("
            s += "est="
            s += repr(self.estimator_)
            s += ", "
            s += "reg_param="
            s += str(self.reg_param)
            s += ")"
            return s
        else:
            s = "> ------------------------------\n"
            s += "> Decision Tree with Hierarchical Shrinkage\n"
            s += "> \tPrediction is made by looking at the value in the appropriate leaf of the tree\n"
            s += "> ------------------------------" + "\n"
            if hasattr(self, "feature_names") and self.feature_names is not None:
                return s + export_text(
                    self.estimator_, feature_names=self.feature_names, show_weights=True
                )
            else:
                return s + export_text(self.estimator_, show_weights=True)

    def __repr__(self):
        attr_list = ["estimator_", "reg_param", "shrinkage_scheme_"]
        s = self.__class__.__name__
        s += "("
        for attr in attr_list:
            s += attr + "=" + repr(getattr(self, attr)) + ", "
        s = s[:-2] + ")"
        return s


class HSTreeRegressor(HSTree, RegressorMixin):
    def __init__(
        self,
        estimator_: BaseEstimator = DecisionTreeRegressor(max_leaf_nodes=20),
        reg_param: float = 1,
        shrinkage_scheme_: str = "node_based",
        max_leaf_nodes: int = None,
        random_state: int = None,
    ):
        super().__init__(
            estimator_=estimator_,
            reg_param=reg_param,
            shrinkage_scheme_=shrinkage_scheme_,
            max_leaf_nodes=max_leaf_nodes,
            random_state=random_state,
        )


class HSTreeClassifier(HSTree, ClassifierMixin):
    def __init__(
        self,
        estimator_: BaseEstimator = DecisionTreeClassifier(max_leaf_nodes=20),
        reg_param: float = 1,
        shrinkage_scheme_: str = "node_based",
        max_leaf_nodes: int = None,
        random_state: int = None,
    ):
        super().__init__(
            estimator_=estimator_,
            reg_param=reg_param,
            shrinkage_scheme_=shrinkage_scheme_,
            max_leaf_nodes=max_leaf_nodes,
            random_state=random_state,
        )
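

# Probe whether a scorer is "higher is better" (e.g. accuracy) or "lower is
# better" (e.g. log-loss, MSE) by comparing perfect vs. random predictions on
# synthetic labels; returns np.argmax or np.argmin accordingly, for selecting
# the best reg_param from the CV scores below.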
def _get_cv_criterion(scorer):
    y_true = np.random.binomial(n=1, p=0.5, size=100)
    y_pred_good = y_true
    y_pred_bad = np.random.uniform(0, 1, 100)
    score_good = scorer(y_true, y_pred_good)
    score_bad = scorer(y_true, y_pred_bad)
    if score_good > score_bad:
        return np.argmax
    elif score_good < score_bad:
        return np.argmin


class HSTreeClassifierCV(HSTreeClassifier):
    def __init__(
        self,
        estimator_: BaseEstimator = None,
        reg_param_list: List[float] = [0, 0.1, 1, 10, 50, 100, 500],
        shrinkage_scheme_: str = "node_based",
        max_leaf_nodes: int = 20,
        cv: int = 3,
        scoring=None,
        *args,
        **kwargs
    ):
        """Cross-validation is used to select the best regularization parameter for hierarchical shrinkage.

        Params
        ------
        estimator_
            Sklearn estimator (already initialized).
            If no estimator_ is passed, a sklearn decision tree is used.
        max_leaf_nodes
            If estimator_ is None, then max_leaf_nodes is passed to the default decision tree.
        args, kwargs
            Note: args and kwargs are not used, but are left so that imodels-experiments can still pass redundant args.
        """
        if estimator_ is None:
            estimator_ = DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)
        super().__init__(estimator_, reg_param=None)
        self.reg_param_list = np.array(reg_param_list)
        self.cv = cv
        self.scoring = scoring
        self.shrinkage_scheme_ = shrinkage_scheme_

    def fit(self, X, y, *args, **kwargs):
        self.scores_ = [[] for _ in self.reg_param_list]
        # use the scoring function passed to fit (popped so it is not forwarded
        # to the underlying estimator), else the one from the constructor,
        # else log-loss
        scorer = kwargs.pop("scoring", self.scoring) or log_loss
        kf = KFold(n_splits=self.cv)
        for train_index, test_index in kf.split(X):
            X_out, y_out = X[test_index, :], y[test_index]
            X_in, y_in = X[train_index, :], y[train_index]
            base_est = deepcopy(self.estimator_)
            base_est.fit(X_in, y_in)
            for i, reg_param in enumerate(self.reg_param_list):
                est_hs = HSTreeClassifier(base_est, reg_param)
                est_hs.fit(X_in, y_in, *args, **kwargs)
                self.scores_[i].append(
                    scorer(y_out, est_hs.predict_proba(X_out)))
        self.scores_ = [np.mean(s) for s in self.scores_]
        cv_criterion = _get_cv_criterion(scorer)
        self.reg_param = self.reg_param_list[cv_criterion(self.scores_)]
        super().fit(X=X, y=y, *args, **kwargs)

    def __repr__(self):
        attr_list = [
            "estimator_",
            "reg_param_list",
            "shrinkage_scheme_",
            "cv",
            "scoring",
        ]
        s = self.__class__.__name__
        s += "("
        for attr in attr_list:
            s += attr + "=" + repr(getattr(self, attr)) + ", "
        s = s[:-2] + ")"
        return s
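

# A minimal usage sketch for the CV wrapper (illustrative, not part of the
# original file; the dataset and split choices are assumptions):
#
#   X, y = datasets.load_breast_cancer(return_X_y=True)
#   X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)
#   m = HSTreeClassifierCV(max_leaf_nodes=20, cv=3)
#   m.fit(X_tr, y_tr)  # selects reg_param from reg_param_list via 3-fold CV
#   print(m.reg_param, m.score(X_te, y_te))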


class HSTreeRegressorCV(HSTreeRegressor):
    def __init__(
        self,
        estimator_: BaseEstimator = None,
        reg_param_list: List[float] = [0, 0.1, 1, 10, 50, 100, 500],
        shrinkage_scheme_: str = "node_based",
        max_leaf_nodes: int = 20,
        cv: int = 3,
        scoring=None,
        *args,
        **kwargs
    ):
        """Cross-validation is used to select the best regularization parameter for hierarchical shrinkage.

        Params
        ------
        estimator_
            Sklearn estimator (already initialized).
            If no estimator_ is passed, a sklearn decision tree is used.
        max_leaf_nodes
            If estimator_ is None, then max_leaf_nodes is passed to the default decision tree.
        args, kwargs
            Note: args and kwargs are not used, but are left so that imodels-experiments can still pass redundant args.
        """
        if estimator_ is None:
            estimator_ = DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes)
        super().__init__(estimator_, reg_param=None)
        self.reg_param_list = np.array(reg_param_list)
        self.cv = cv
        self.scoring = scoring
        self.shrinkage_scheme_ = shrinkage_scheme_

    def fit(self, X, y, *args, **kwargs):
        self.scores_ = [[] for _ in self.reg_param_list]
        kf = KFold(n_splits=self.cv)
        # use the scoring function passed to fit (popped so it is not forwarded
        # to the underlying estimator), else the one from the constructor,
        # else mean squared error
        scorer = kwargs.pop("scoring", self.scoring) or mean_squared_error
        for train_index, test_index in kf.split(X):
            X_out, y_out = X[test_index, :], y[test_index]
            X_in, y_in = X[train_index, :], y[train_index]
            base_est = deepcopy(self.estimator_)
            base_est.fit(X_in, y_in)
            for i, reg_param in enumerate(self.reg_param_list):
                est_hs = HSTreeRegressor(base_est, reg_param)
                est_hs.fit(X_in, y_in)
                self.scores_[i].append(scorer(y_out, est_hs.predict(X_out)))
        self.scores_ = [np.mean(s) for s in self.scores_]
        cv_criterion = _get_cv_criterion(scorer)
        self.reg_param = self.reg_param_list[cv_criterion(self.scores_)]
        super().fit(X=X, y=y, *args, **kwargs)

    def __repr__(self):
        attr_list = [
            "estimator_",
            "reg_param_list",
            "shrinkage_scheme_",
            "cv",
            "scoring",
        ]
        s = self.__class__.__name__
        s += "("
        for attr in attr_list:
            s += attr + "=" + repr(getattr(self, attr)) + ", "
        s = s[:-2] + ")"
        return s


if __name__ == "__main__":
    np.random.seed(15)
    # X, y = datasets.fetch_california_housing(return_X_y=True)  # regression
    # X, y = datasets.load_breast_cancer(return_X_y=True)  # binary classification
    X, y = datasets.load_diabetes(return_X_y=True)  # regression
    # X = np.random.randn(500, 10)
    # y = (X[:, 0] > 0).astype(float) + (X[:, 1] > 1).astype(float)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=10
    )
    print("X.shape", X.shape)
    print("ys", np.unique(y_train))

    # m = HSTree(estimator_=DecisionTreeClassifier(), reg_param=0.1)
    # m = DecisionTreeClassifier(max_leaf_nodes=20, random_state=1, max_features=None)
    # m = DecisionTreeClassifier(random_state=42)
    m = GradientBoostingRegressor(random_state=10, n_estimators=5)
    m.fit(X_train, y_train)
    print("score", r2_score(y_test, m.predict(X_test)))

    print("running again....")
    # x = DecisionTreeRegressor(random_state=42, ccp_alpha=0.3)
    # x.fit(X_train, y_train)
    # m = HSTree(estimator_=DecisionTreeRegressor(random_state=42, max_features=None), reg_param=10)
    # m = HSTree(estimator_=DecisionTreeClassifier(random_state=42, max_features=None), reg_param=0)
    # m = HSTreeRegressorCV(
    #     estimator_=DecisionTreeClassifier(random_state=42),
    #     shrinkage_scheme_="node_based",
    #     reg_param_list=[0.1, 1, 2, 5, 10, 25, 50, 100, 500],
    # )
    # wrapping an already-fitted model shrinks it immediately in __init__
    m = HSTreeRegressor(m)
    print("score", r2_score(y_test, m.predict(X_test)))

    m = HSTreeRegressor(
        estimator_=GradientBoostingRegressor(
            random_state=10,
            n_estimators=5,
        ),
        reg_param=1,
    )
    m.fit(X_train, y_train)
    print("best alpha", m.reg_param)
    print("score", r2_score(y_test, m.predict(X_test)))