Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

shrinkage.html 14 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
  1. <!doctype html>
  2. <html lang="en">
  3. <head>
  4. <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
  5. <script type="text/javascript">window.PlotlyConfig = {MathJaxConfig: 'local'};</script> <!-- PlotlyConfig must be defined before the plotly.js script tag below -->
  6. <script src="https://cdn.plot.ly/plotly-2.6.3.min.js"></script>
  7. <meta charset="utf-8">
  8. <meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1" />
  9. <meta name="generator" content="pdoc 0.10.0" />
  10. <link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/sanitize.min.css" integrity="sha256-PK9q560IAAa6WVRRh76LtCaI8pjTJ2z11v0miyNNjrs=" crossorigin>
  11. <link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/typography.min.css" integrity="sha256-7l/o7C8jubJiy74VsKTidCy1yBkRtiUGbVkYBylBqUg=" crossorigin>
  12. <link rel="stylesheet preload" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/github.min.css" crossorigin>
  13. <style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:30px;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:1em 0 .50em 0}h3{font-size:1.4em;margin:25px 0 10px 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .3s ease-in-out}a:hover{color:#e82}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900}pre code{background:#f8f8f8;font-size:.8em;line-height:1.4em}code{background:#f2f2f1;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{background:#f8f8f8;border:0;border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0;padding:1ex}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-weight:bold;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > 
span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em .5em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
  14. <style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.item .name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul{padding-left:1.5em}.toc > ul > li{margin-top:.5em}}</style>
  15. <style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
  16. <script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js" integrity="sha256-Uv3H6lx7dJmRfRvH8TH6kJD1TSK1aFcwgx+mdg3epi8=" crossorigin></script>
  17. <script>window.addEventListener('DOMContentLoaded', () => hljs.initHighlighting())</script>
  18. </head>
  19. <body>
  20. <main>
  21. <article id="content">
  22. <section id="section-intro">
  23. <p align="center">
  24. <img align="center" width="60%" src="https://csinva.io/imodels/img/imodels_logo.svg?sanitize=True&kill_cache=1" alt="imodels logo">
  25. <br/>
  26. <div class="article" style="padding-right: 2%; padding-left: 2%;">
  27. <p align="center">
  28. <h1 text-align="center" style="display: inline;padding-bottom: 0px;"> Hierarchical shrinkage: improving the accuracy and interpretability
  29. of tree-based methods </h1>
  30. <h3 style="color:gray;padding-top:0px;">Abhineet Agarwal*, Yan Shuo Tan*, Omer Ronen, Chandan Singh, Bin
  31. Yu </h3>
  32. <hr>
  33. </p>
  34. <p><a href="https://arxiv.org/abs/2202.00858">📄 Paper (ICML 2022)</a>, <a href="https://csinva.io/imodels/tree/hierarchical_shrinkage.html">🗂 Doc</a>, <a href="https://scholar.google.com/scholar?hl=en&as_sdt=0%2C5&q=hierarchical+shrinkage+singh&btnG=&oq=hierar#d=gs_cit&u=%2Fscholar%3Fq%3Dinfo%3Azc6gtLx-aL4J%3Ascholar.google.com%2F%26output%3Dcite%26scirp%3D0%26hl%3Den">📌 Citation</a></p>
  35. Hierarchical shrinkage is an extremely fast post-hoc regularization method which works on any decision tree (or tree-based ensemble, such as Random Forest). It does not modify the tree structure, and instead regularizes the tree by shrinking the prediction over each node towards the sample means of its ancestors (using a single regularization parameter). Experiments over a wide variety of datasets show that hierarchical shrinkage substantially increases the predictive performance of individual decision trees and decision-tree ensembles.
  36. <h2>How does Hierarchical shrinkage work? </h2>
  37. <p align="center">
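  38. <img src="https://demos.csinva.io/shrinkage/shrinkage_intro.svg?sanitize=True" width="90%" alt="Diagram of hierarchical shrinkage: each node of a trained decision tree is shrunk towards its parent">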
  39. </p>
  40. <p align="center">
  41. <i><b>Fig 1.</b> HS applies post-hoc regularization to any decision tree by shrinking each node towards its parent. This is done after a tree has been trained. The amount of shrinkage can be varied using a regularization parameter
  42. (this works best if the parameter is chosen via cross-validation). </i>
  43. </p>
  44. <h2>An example using HS</h2>
  45. HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the <code>fit</code> and <code>predict</code> methods.
  46. Here's a full example of using it on a sample clinical dataset.
  47. <pre>
  48. <code>
  49. from imodels import HSTreeClassifierCV, get_clean_dataset
  50. from sklearn.model_selection import train_test_split
  51. from sklearn.tree import plot_tree
  52. # prepare data (in this case, a sample clinical dataset)
  53. X, y, feat_names = get_clean_dataset('csi_pecarn_pred')
  54. X_train, X_test, y_train, y_test = train_test_split(
  55. X, y, test_size=0.33, random_state=42)
  56. # fit the model
  57. model = HSTreeClassifierCV(max_leaf_nodes=7) # initialize a model
  58. model.fit(X_train, y_train) # fit model
  59. preds = model.predict(X_test) # discrete predictions: shape is (n_test,)
  60. preds_proba = model.predict_proba(X_test) # predicted probabilities: shape is (n_test, n_classes)
  61. # visualize the model
  62. plot_tree(model.estimator_, feature_names=feat_names)
  63. </code>
  64. </pre>
  65. Here we used <code>HSTreeClassifierCV</code>, which selects the amount of regularization to use via cross-validation, but we can also use <code>HSTreeClassifier</code> if we want to specify a particular amount of regularization.
  66. For regression, we can use the corresponding classes: <code>HSTreeRegressorCV</code> and <code>HSTreeRegressor</code>.
  67. <p style="text-align:center;">
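  68. <a href="https://github.com/csinva/imodels"><img src="https://demos.csinva.io/shrinkage/shrinkage_csi_model.svg?sanitize=True" width="100%" alt="Decision tree learned by HS for predicting risk of cervical spinal injury (links to the imodels GitHub repository)"></a>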
  69. <br>
  70. <i><b>Fig 2.</b> Simple model learned by HS for predicting risk of cervical spinal injury. </i>
  71. </p>
  72. <h2>Examples with HS on synthetic data</h2>
  73. <p>See some examples of how hierarchical shrinkage works on one-dimensional functions which are fitted with a CART decision tree.</p>
  74. <!-- <div class="justify-content-center" align="center">-->
  75. <!-- <div class="col-lg-8">-->
  76. <!-- <h1> Tree shrinkage visualizations </h1>-->
  77. <div id="4db2fbb0-073d-492f-8f30-3e9fe37bf9f2" class="plotly-graph-div"
  78. style="height:60%; width:100%;"></div>
  79. <script type="text/javascript" src="https://demos.csinva.io/shrinkage/shrinkage_steps.js"></script>
  80. <p style="text-align:center;">
  81. <i><b>Fig 3.</b> Step function. </i>
  82. </p>
  83. <!-- <link rel="stylesheet"-->
  84. <!-- href="https://cdnjs.cloudflare.com/ajax/libs/startbootstrap-clean-blog/5.0.10/css/clean-blog.min.css">-->
  85. <div id="ae4f9428-3c0a-4f94-a17d-6dbbc72c7d38" class="plotly-graph-div"
  86. style="height:60%; width:100%;"></div>
  87. <script type="text/javascript" src="https://demos.csinva.io/shrinkage/shrinkage_linear.js"></script>
  88. <!-- </div>-->
  89. <p style="text-align:center;">
  90. <i><b>Fig 4.</b> Linear function. </i>
  91. </p>
  92. </div>
  93. <h2>Applying HS to tree ensembles</h2>
  94. HS can also be used on tree ensembles to regularize each tree in an ensemble (e.g. in a Random Forest).
  95. We must simply pass the desired estimator during initialization.
  96. <pre>
  97. <code>
  98. from sklearn.ensemble import RandomForestClassifier # also works with ExtraTreesClassifier, GradientBoostingClassifier
  99. from imodels import HSTreeClassifier
  100. ensemble = RandomForestClassifier()
  101. model = HSTreeClassifier(estimator_=ensemble)
  102. model = model.fit(X_train, y_train)
  103. </code>
  104. </pre>
  105. </section>
  106. <section>
  107. </section>
  108. <section>
  109. </section>
  110. <section>
  111. </section>
  112. </article>
  113. <nav id="sidebar">
  114. <h1>Index 🔍</h1>
  115. <div class="toc">
  116. <ul>
  117. <li><a href="index.html#installation">Installation</a></li>
  118. <li><a href="index.html#supported-models">Supported models</a></li>
  119. <li><a href="index.html#whats-the-difference-between-the-models">What's the difference between the models?</a></li>
  120. <li><a href="index.html#demo-notebooks">Demo notebooks</a></li>
  121. <li><a href="index.html#support-for-different-tasks">Support for different tasks</a><ul>
  122. <li><a href="index.html#extras">Extras</a></li>
  123. </ul>
  124. </li><li><a href="index.html#references">References</a></li>
  125. </ul>
  126. </div>
  127. <ul id="index">
  128. <li><h3>Our favorite models</h3>
  129. <ul>
  130. <li><a href="https://csinva.io/imodels/shrinkage.html">Hierarchical shrinkage: post-hoc regularization for tree-based methods</a></li>
  131. <li><a href="https://csinva.io/imodels/figs.html">FIGS: Fast interpretable greedy-tree sums</a></li>
  132. </ul>
  133. </li>
  134. <li><h3><a href="index.html#header-submodules">Sub-modules</a></h3>
  135. <ul>
  136. <li><code><a title="imodels.algebraic" href="algebraic/index.html">imodels.algebraic</a></code></li>
  137. <li><code><a title="imodels.discretization" href="discretization/index.html">imodels.discretization</a></code></li>
  138. <li><code><a title="imodels.experimental" href="experimental/index.html">imodels.experimental</a></code></li>
  139. <li><code><a title="imodels.rule_list" href="rule_list/index.html">imodels.rule_list</a></code></li>
  140. <li><code><a title="imodels.rule_set" href="rule_set/index.html">imodels.rule_set</a></code></li>
  141. <li><code><a title="imodels.tree" href="tree/index.html">imodels.tree</a></code></li>
  142. <li><code><a title="imodels.util" href="util/index.html">imodels.util</a></code></li>
  143. </ul>
  144. </li>
  145. </ul>
  146. <p><img align="center" width="100%" src="https://csinva.io/imodels/img/anim.gif" alt="imodels animation"></p>
  147. <!-- add wave animation -->
  148. </nav>
  149. </main>
  150. <footer id="footer">
  151. </footer>
  152. </body>
  153. </html>
  154. <!-- add github corner -->
  155. <a href="https://github.com/csinva/imodels" class="github-corner" aria-label="View source on GitHub"><svg width="120" height="120" viewBox="0 0 250 250" style="fill:#70B7FD; color:#fff; position: absolute; top: 0; border: 0; right: 0;" aria-hidden="true"><path d="M0,0 L115,115 L130,115 L142,142 L250,250 L250,0 Z"></path><path d="m128.3,109.0 c113.8,99.7 119.0,89.6 119.0,89.6 c122.0,82.7 120.5,78.6 120.5,78.6 c119.2,72.0 123.4,76.3 123.4,76.3 c127.3,80.9 125.5,87.3 125.5,87.3 c122.9,97.6 130.6,101.9 134.4,103.2" fill="currentcolor" style="transform-origin: 130px 106px;" class="octo-arm"></path><path d="M115.0,115.0 C114.9,115.1 118.7,116.5 119.8,115.4 L133.7,101.6 C136.9,99.2 139.9,98.4 142.2,98.6 C133.8,88.0 127.5,74.4 143.8,58.0 C148.5,53.4 154.0,51.2 159.7,51.0 C160.3,49.4 163.2,43.6 171.4,40.1 C171.4,40.1 176.1,42.5 178.8,56.2 C183.1,58.6 187.2,61.8 190.9,65.4 C194.5,69.0 197.7,73.2 200.1,77.6 C213.8,80.2 216.3,84.9 216.3,84.9 C212.7,93.1 206.9,96.0 205.4,96.6 C205.1,102.4 203.0,107.8 198.3,112.5 C181.9,128.9 168.3,122.5 157.7,114.1 C157.9,116.9 156.7,120.9 152.7,124.9 L141.0,136.5 C139.8,137.7 141.6,141.9 141.8,141.8 Z" fill="currentColor" class="octo-body"></path></svg></a><style>.github-corner:hover .octo-arm{animation:octocat-wave 560ms ease-in-out}@keyframes octocat-wave{0%,100%{transform:rotate(0)}20%,60%{transform:rotate(-25deg)}40%,80%{transform:rotate(10deg)}}@media (max-width:500px){.github-corner:hover .octo-arm{animation:none}.github-corner .octo-arm{animation:octocat-wave 560ms ease-in-out}}</style>
  156. <!-- add wave animation stylesheet -->
  157. <!--<link href="wave.css" rel="stylesheet">-->
  158. <link rel="stylesheet" href="github.css">
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...