Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

figs.html 17 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
  1. <!doctype html>
  2. <html lang="en">
  3. <head>
  4. <!-- <script src="jquery-3.5.1.min.js"></script>-->
  5. <script src="https://cdn.plot.ly/plotly-2.6.3.min.js"></script>
  6. <script type="text/javascript">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>
  7. <meta charset="utf-8">
  8. <meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1" />
  9. <meta name="generator" content="pdoc 0.10.0" />
  10. <link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/sanitize.min.css" integrity="sha256-PK9q560IAAa6WVRRh76LtCaI8pjTJ2z11v0miyNNjrs=" crossorigin>
  11. <link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/typography.min.css" integrity="sha256-7l/o7C8jubJiy74VsKTidCy1yBkRtiUGbVkYBylBqUg=" crossorigin>
  12. <link rel="stylesheet preload" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/github.min.css" crossorigin>
  13. <style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:30px;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:1em 0 .50em 0}h3{font-size:1.4em;margin:25px 0 10px 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .3s ease-in-out}a:hover{color:#e82}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900}pre code{background:#f8f8f8;font-size:.8em;line-height:1.4em}code{background:#f2f2f1;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{background:#f8f8f8;border:0;border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0;padding:1ex}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-weight:bold;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > 
span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em .5em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
  14. <style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.item .name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul{padding-left:1.5em}.toc > ul > li{margin-top:.5em}}</style>
  15. <style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
  16. <script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js" integrity="sha256-Uv3H6lx7dJmRfRvH8TH6kJD1TSK1aFcwgx+mdg3epi8=" crossorigin></script>
  17. <script>window.addEventListener('DOMContentLoaded', () => hljs.initHighlighting())</script>
  18. <link rel="stylesheet" href="https://demos.csinva.io/figs/style.css">
  19. <script type="text/javascript">
  20. function drawgif() {
  21. var wrapper = document.querySelector('svg')
  22. wrapper.classList.add('active')
  23. }
  24. </script>
  25. </head>
  26. <body>
  27. <main>
  28. <article id="content">
  29. <section id="section-intro">
  30. <p align="center">
  31. <img align="center" width=60% src="https://csinva.io/imodels/img/imodels_logo.svg?sanitize=True&kill_cache=1"> </img>
  32. <br/>
  33. <div class="article" style="padding-right: 2%; padding-left: 2%;">
  34. <p align="center">
  35. <h1 text-align="center" style="display: inline;padding-bottom: 0px;"> FIGS: Fast Interpretable Greedy-Tree
  36. Sums </h1>
  37. </p>
  38. <hr>
  39. <nav class="nav">
  40. <!-- <button class="js-reset butt">Reset</button>-->
  41. <!-- <div class="butt disabled"><span class="js-count">0</span>/<span class="js-total">100</div>-->
  42. </nav>
  43. <div id="p5" class="container">
  44. <p class="overlay">Click and drag for more FIGS</p>
  45. </div>
  46. <!-- partial -->
  47. <script src='https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.7.3/p5.min.js'></script>
  48. <script src='https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.7.3/addons/p5.dom.min.js'></script>
  49. <script src='https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.7.3/addons/p5.sound.min.js'></script>
  50. <script src='https://unpkg.com/matter-js@0.14.2/build/matter.min.js'></script>
  51. <script src="https://demos.csinva.io/figs/script.js"></script>
  52. <p><a href="https://arxiv.org/abs/2202.00858">📄 Paper</a>, <a href="https://csinva.io/imodels/tree/figs.html#imodels.tree.figs">🗂 Doc</a>, <a href="https://scholar.google.com/scholar?hl=en&as_sdt=0%2C5&q=fast+interpretable+greedy-tree+sums&oq=fast#d=gs_cit&u=%2Fscholar%3Fq%3Dinfo%3ADnPVL74Rop0J%3Ascholar.google.com%2F%26output%3Dcite%26scirp%3D0%26hl%3Den">📌 Citation</a></p>
  53. <h3 style="color:gray;padding-top:0px;">Yan Shuo Tan*, Chandan Singh*, Keyan Nasseri*, Abhineet Agarwal, James Duncan, Omer Ronen, Aaron Kornblith, Bin Yu</h3>
  54. <p>
  55. Modern machine learning has achieved impressive prediction performance, but often sacrifices interpretability, a
  56. critical consideration in many problems.
  57. Here, we propose Fast Interpretable Greedy-Tree Sums (FIGS), an algorithm for fitting concise rule-based
  58. models.
  59. Specifically, FIGS generalizes the CART algorithm to work on sums of trees, growing a flexible number of
  60. them simultaneously.
  61. The total number of splits across all the trees is restricted by a pre-specified threshold, which ensures that
  62. FIGS remains interpretable.
  63. Extensive experiments show that FIGS achieves state-of-the-art performance across a wide array of real-world
  64. datasets when restricted to very few splits (e.g. less than 20).
  65. Theoretical and simulation results suggest that FIGS overcomes a key weakness of single-tree models by
  66. disentangling additive components of generative additive models, thereby significantly improving convergence
  67. rates for l2 generalization error.
  68. We further characterize the success of FIGS by quantifying how it reduces repeated splits, which can lead to
  69. redundancy in single-tree models such as CART.
  70. All code and models are released in a full-fledged package available on GitHub.
  71. </p>
  72. <h2>How does FIGS work? </h2>
  73. Intuitively, FIGS works by extending CART, a typical greedy algorithm for growing a decision tree, to consider growing a <i>sum</i> of trees <i>simultaneously</i> (see Fig 1). At each iteration, FIGS may grow any existing tree it has already started or start a new tree; it greedily selects whichever rule reduces the total unexplained variance (or an alternative splitting criterion) the most. To keep the trees in sync with one another, each tree is made to predict the <i>residuals</i> remaining after summing the predictions of all other trees.
  74. <br>
  75. <br>
  76. FIGS is intuitively similar to ensemble approaches such as gradient boosting / random forest, but importantly since all trees are grown to compete with each other the model can adapt more to the underlying structure in the data. The number of trees and size/shape of each tree emerge automatically from the data rather than being manually specified.
  77. <p style="text-align:center;">
  78. <a href="https://github.com/csinva/imodels"><img src="https://demos.csinva.io/figs/figs_fitting.gif?sanitize=True" width="90%"></a>
  79. <br>
  80. <b>Fig 1. </b><i>High-level intuition for how FIGS fits a model.</i>
  81. </p>
  82. <h2>An example using FIGS</h2>
  83. FIGS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the <code>fit</code> and <code>predict</code> methods. Here's a full example of using it on a sample clinical dataset.
  84. <pre>
  85. <code>
  86. from imodels import FIGSClassifier, get_clean_dataset
  87. from sklearn.model_selection import train_test_split
  88. # prepare data (in this case, a sample clinical dataset)
  89. X, y, feat_names = get_clean_dataset('csi_pecarn_pred')
  90. X_train, X_test, y_train, y_test = train_test_split(
  91. X, y, test_size=0.33, random_state=42)
  92. # fit the model
  93. model = FIGSClassifier(max_rules=4) # initialize a model
  94. model.fit(X_train, y_train) # fit model
  95. preds = model.predict(X_test) # discrete predictions: shape is (n_test, 1)
  96. preds_proba = model.predict_proba(X_test) # predicted probabilities: shape is (n_test, n_classes)
  97. # visualize the model
  98. model.plot(feature_names=feat_names, filename='out.svg', dpi=300)
  99. </code>
  100. </pre>
  101. This results in a simple model -- it contains only 4 splits, since we specified that the model should have no more than 4 splits (<code>max_rules=4</code>). Predictions are made by summing the value obtained from the appropriate leaf of each tree. This model is extremely interpretable, as a physician can now (i) easily make predictions using the 4 relevant features and (ii) vet the model to ensure it matches their domain expertise. Note that this model is just for illustration purposes, and achieves ~84% accuracy.
  102. <p style="text-align:center;">
  103. <a href="https://github.com/csinva/imodels"><img src="https://demos.csinva.io/figs/figs_csi_model_small.svg?sanitize=True" width="85%"></a>
  104. <br>
  105. <i><b>Fig 2.</b> Simple model learned by FIGS for predicting risk of cervical spinal injury. </i>
  106. </p>
  107. If we want a more flexible model, we can also remove the constraint on the number of rules (changing the code to <code>model = FIGSClassifier()</code>), resulting in a larger model (see Fig 3). Note that the number of trees and how balanced they are emerge from the structure of the data -- only the total number of rules may be specified.
  108. <p style="text-align:center;">
  109. <a href="https://github.com/csinva/imodels"><img src="https://demos.csinva.io/figs/figs_csi_model_large.svg?sanitize=True" width="100%"></a>
  110. <br>
  111. <i><b>Fig 3.</b> Slightly larger model learned by FIGS for predicting risk of cervical spinal injury. </i>
  112. </p>
  113. <h2>Another example of using FIGS</h2>
  114. <p>
  115. Here, we examine the <a href="https://www.sciencedirect.com/science/article/pii/S0140673671923038">Diabetes
  116. classification dataset</a>, in which eight risk factors were collected and used to predict the onset of diabetes
  117. within five years. Fitting several models, we find that with very few rules, the model can achieve excellent
  118. test performance.
  119. </p>
  120. <p>
  121. For example, Fig 4 shows a model fitted using the FIGS algorithm which achieves a test-AUC of 0.820 despite
  122. being extremely simple. In this model, each feature contributes independently of the others, and the final risks
  123. from each of three key features are summed to get a risk for the onset of diabetes (higher means higher risk). As
  124. opposed to a black-box model, this model is easy to interpret, fast to compute with, and allows us to vet the
  125. features being used for decision-making.
  126. </p>
  127. <p style="text-align:center;">
  128. <img width=60% src="https://demos.csinva.io/figs/diabetes_figs.svg" title="diabetes FIGS"></img>
  129. <br>
  130. <i><b>Fig 4.</b> Simple model learned by <a href="">FIGS</a> for diabetes risk prediction. </i>
  131. </p>
  132. </div>
  133. </section>
  134. <section>
  135. </section>
  136. <section>
  137. </section>
  138. <section>
  139. </section>
  140. </article>
  141. <nav id="sidebar">
  142. <h1>Index 🔍</h1>
  143. <div class="toc">
  144. <ul>
  145. <li><a href="index.html#installation">Installation</a></li>
  146. <li><a href="index.html#supported-models">Supported models</a></li>
  147. <li><a href="index.html#whats-the-difference-between-the-models">What's the difference between the models?</a></li>
  148. <li><a href="index.html#demo-notebooks">Demo notebooks</a></li>
  149. <li><a href="index.html#support-for-different-tasks">Support for different tasks</a><ul>
  150. <li><a href="index.html#extras">Extras</a></li>
  151. </ul>
  152. </li><li><a href="index.html#references">References</a></li>
  153. </ul>
  154. </div>
  155. <ul id="index">
  156. <li><h3>Our favorite models</h3>
  157. <ul>
  158. <li><a href="https://csinva.io/imodels/shrinkage.html">Hierarchical shrinkage: post-hoc regularization for tree-based methods</a></li>
  159. <li><a href="https://csinva.io/imodels/figs.html">FIGS: Fast interpretable greedy-tree sums</a></li>
  160. </ul>
  161. </li>
  162. <li><h3><a href="#header-submodules">Sub-modules</a></h3>
  163. <ul>
  164. <li><code><a title="imodels.algebraic" href="algebraic/index.html">imodels.algebraic</a></code></li>
  165. <li><code><a title="imodels.discretization" href="discretization/index.html">imodels.discretization</a></code></li>
  166. <li><code><a title="imodels.experimental" href="experimental/index.html">imodels.experimental</a></code></li>
  167. <li><code><a title="imodels.rule_list" href="rule_list/index.html">imodels.rule_list</a></code></li>
  168. <li><code><a title="imodels.rule_set" href="rule_set/index.html">imodels.rule_set</a></code></li>
  169. <li><code><a title="imodels.tree" href="tree/index.html">imodels.tree</a></code></li>
  170. <li><code><a title="imodels.util" href="util/index.html">imodels.util</a></code></li>
  171. </ul>
  172. </li>
  173. </ul>
  174. <p><img align="center" width=100% src="https://csinva.io/imodels/img/anim.gif"> </img></p>
  175. <!-- add wave animation -->
  176. </nav>
  177. </main>
  178. <footer id="footer">
  179. </footer>
  180. </body>
  181. </html>
  182. <!-- add github corner -->
  183. <a href="https://github.com/csinva/imodels" class="github-corner" aria-label="View source on GitHub"><svg width="120" height="120" viewBox="0 0 250 250" style="fill:#70B7FD; color:#fff; position: absolute; top: 0; border: 0; right: 0;" aria-hidden="true"><path d="M0,0 L115,115 L130,115 L142,142 L250,250 L250,0 Z"></path><path d="m128.3,109.0 c113.8,99.7 119.0,89.6 119.0,89.6 c122.0,82.7 120.5,78.6 120.5,78.6 c119.2,72.0 123.4,76.3 123.4,76.3 c127.3,80.9 125.5,87.3 125.5,87.3 c122.9,97.6 130.6,101.9 134.4,103.2" fill="currentcolor" style="transform-origin: 130px 106px;" class="octo-arm"></path><path d="M115.0,115.0 C114.9,115.1 118.7,116.5 119.8,115.4 L133.7,101.6 C136.9,99.2 139.9,98.4 142.2,98.6 C133.8,88.0 127.5,74.4 143.8,58.0 C148.5,53.4 154.0,51.2 159.7,51.0 C160.3,49.4 163.2,43.6 171.4,40.1 C171.4,40.1 176.1,42.5 178.8,56.2 C183.1,58.6 187.2,61.8 190.9,65.4 C194.5,69.0 197.7,73.2 200.1,77.6 C213.8,80.2 216.3,84.9 216.3,84.9 C212.7,93.1 206.9,96.0 205.4,96.6 C205.1,102.4 203.0,107.8 198.3,112.5 C181.9,128.9 168.3,122.5 157.7,114.1 C157.9,116.9 156.7,120.9 152.7,124.9 L141.0,136.5 C139.8,137.7 141.6,141.9 141.8,141.8 Z" fill="currentColor" class="octo-body"></path></svg></a><style>.github-corner:hover .octo-arm{animation:octocat-wave 560ms ease-in-out}@keyframes octocat-wave{0%,100%{transform:rotate(0)}20%,60%{transform:rotate(-25deg)}40%,80%{transform:rotate(10deg)}}@media (max-width:500px){.github-corner:hover .octo-arm{animation:none}.github-corner .octo-arm{animation:octocat-wave 560ms ease-in-out}}</style>
  184. <!-- add wave animation stylesheet -->
  185. <!--<link href="wave.css" rel="stylesheet">-->
  186. <link rel="stylesheet" href="github.css">
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...