mdi_plus.py

import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist
from functools import partial

from .ppms import PartialPredictionModelBase, GenericRegressorPPM, GenericClassifierPPM
from .block_transformers import _blocked_train_test_split
from .ranking_stability import tauAP_b, rbo


class ForestMDIPlus:
    """
    The class object for computing MDI+ feature importances for a forest or collection of trees.

    Generalized mean decrease in impurity (MDI+) is a flexible framework for computing RF
    feature importances. For more details, refer to [paper].

    Parameters
    ----------
    estimators: list of fitted PartialPredictionModelBase objects or scikit-learn type estimators
        The fitted partial prediction models (one per tree) to use for evaluating
        feature importance via MDI+. If not a PartialPredictionModelBase, then
        the estimator is coerced into a PartialPredictionModelBase object via
        GenericRegressorPPM or GenericClassifierPPM, depending on the specified
        task. Note that these generic PPMs may be computationally expensive.
    transformers: list of BlockTransformerBase objects
        The block feature transformers used to generate blocks of engineered
        features for each original feature. The transformed data is then used
        as input into the partial prediction models. Should be the same length
        as estimators.
    scoring_fns: a function or dict with functions as values and function names (str) as keys
        The scoring functions used for evaluating the partial predictions.
    sample_split: string in {"loo", "oob", "inbag"} or None
        The sample splitting strategy to be used when evaluating the partial
        model predictions. The default "loo" (leave-one-out) is strongly
        recommended for performance and, in particular, for overcoming the known
        correlation and entropy biases suffered by MDI. "oob" (out-of-bag) can
        also be used to overcome these biases. "inbag" is the sample splitting
        strategy used by MDI. If None, no sample splitting is performed and the
        full data set is used to evaluate the partial model predictions.
    tree_random_states: list of int or None
        Random states from each tree in the fitted random forest; used in
        sample splitting and only required if sample_split = "oob" or "inbag".
        Should be the same length as estimators.
    mode: string in {"keep_k", "keep_rest"}
        Mode for the method. "keep_k" imputes the mean of each feature not
        in block k when making a partial model prediction, while "keep_rest"
        imputes the mean of each feature in block k. "keep_k" is strongly
        recommended for computational considerations.
    task: string in {"regression", "classification"}
        The supervised learning task for the RF model. Used for choosing
        defaults for the scoring_fns. Currently only regression and
        classification are supported.
    center: bool
        Flag for whether to center the transformed data in the transformers.
    normalize: bool
        Flag for whether to rescale the transformed data to have unit
        variance in the transformers.
    """

    def __init__(self, estimators, transformers, scoring_fns,
                 sample_split="loo", tree_random_states=None, mode="keep_k",
                 task="regression", center=True, normalize=False):
        assert sample_split in ["loo", "oob", "inbag", None]
        assert mode in ["keep_k", "keep_rest"]
        assert task in ["regression", "classification"]
        self.estimators = estimators
        self.transformers = transformers
        self.scoring_fns = scoring_fns
        self.sample_split = sample_split
        self.tree_random_states = tree_random_states
        if self.sample_split in ["oob", "inbag"] and not self.tree_random_states:
            raise ValueError("Must specify tree_random_states to use 'oob' or 'inbag' sample_split.")
        if self.tree_random_states is None:
            # No sample splitting requested: fall back to one placeholder state
            # per tree so the zip in _fit_importance_scores still works.
            self.tree_random_states = [None] * len(estimators)
        self.mode = mode
        self.task = task
        self.center = center
        self.normalize = normalize
        self.is_fitted = False
        self.prediction_score_ = pd.DataFrame({})
        self.feature_importances_ = pd.DataFrame({})
        self.feature_importances_by_tree_ = {}

    def get_scores(self, X, y):
        """
        Obtain the MDI+ feature importances for a forest.

        Parameters
        ----------
        X: ndarray of shape (n_samples, n_features)
            The covariate matrix. If a pd.DataFrame object is supplied, then
            the column names are used in the output.
        y: ndarray of shape (n_samples, n_targets)
            The observed responses.

        Returns
        -------
        scores: pd.DataFrame of shape (n_features, n_scoring_fns)
            The MDI+ feature importances.
        """
        self._fit_importance_scores(X, y)
        return self.feature_importances_

    def get_stability_scores(self, B=10, metrics="auto"):
        """
        Evaluate the stability of the MDI+ feature importance rankings
        across bootstrapped samples of trees. Can be used to select the GLM
        and scoring metric in a data-driven manner, where the GLM and metric that
        yield the most stable feature rankings across bootstrapped samples are selected.

        Parameters
        ----------
        B: int
            Number of bootstrap samples.
        metrics: "auto" or a dict with functions as values and function names (str) as keys
            Metric(s) used to evaluate the stability between two sets of feature importances.
            If "auto", then the feature importance stability metrics are:
            (1) Rank-based overlap (RBO) with p=0.9 (from "A Similarity Measure for
            Indefinite Rankings" by Webber et al. (2010)). Intuitively, this metric gives
            more weight to features with the largest importances, with most of the weight
            going to the ~1/(1-p) features with the largest importances.
            (2) A weighted Kendall tau metric (tauAP_b from "The Treatment of Ties in
            AP Correlation" by Urbano and Marrero (2017)), which also gives more weight
            to the features with the largest importances, but uses a different weighting
            scheme from RBO.
            Note that these default metrics assume that a higher MDI+ score indicates
            greater importance and thus give more weight to features with high
            importances/ranks. If a lower MDI+ score indicates higher importance, then invert
            either these stability metrics or the MDI+ scores before evaluating the stability.

        Returns
        -------
        stability_results: pd.DataFrame of shape (n_scoring_fns, n_metrics)
            The stability scores of the MDI+ feature rankings across bootstrapped samples.
        """
        if metrics == "auto":
            metrics = {"RBO": partial(rbo, p=0.9), "tauAP": tauAP_b}
        elif not isinstance(metrics, dict):
            raise ValueError("`metrics` must be 'auto' or a dictionary "
                             "where the key is the metric name and the value is the evaluation function")
        single_scoring_fn = not isinstance(self.feature_importances_by_tree_, dict)
        if single_scoring_fn:
            feature_importances_dict = {"mdi_plus_score": self.feature_importances_by_tree_}
        else:
            feature_importances_dict = self.feature_importances_by_tree_
        stability_dict = {}
        for scoring_fn_name, feature_importances_by_tree in feature_importances_dict.items():
            n_trees = feature_importances_by_tree.shape[1]
            fi_scores_boot_ls = []
            for b in range(B):
                # Resample trees with replacement and average their importances.
                bootstrap_sample = np.random.choice(n_trees, n_trees, replace=True)
                fi_scores_boot_ls.append(feature_importances_by_tree[bootstrap_sample].mean(axis=1))
            fi_scores_boot = pd.concat(fi_scores_boot_ls, axis=1)
            stability_results = {"scorer": [scoring_fn_name]}
            for metric_name, metric_fun in metrics.items():
                # Mean pairwise stability over all pairs of bootstrap replicates.
                stability_results[metric_name] = [np.mean(pdist(fi_scores_boot.T, metric=metric_fun))]
            stability_dict[scoring_fn_name] = pd.DataFrame(stability_results)
        stability_df = pd.concat(stability_dict, axis=0).reset_index(drop=True)
        if single_scoring_fn:
            stability_df = stability_df.drop(columns=["scorer"])
        return stability_df

    def _fit_importance_scores(self, X, y):
        all_scores = []
        all_full_preds = []
        for estimator, transformer, tree_random_state in \
                zip(self.estimators, self.transformers, self.tree_random_states):
            tree_mdi_plus = TreeMDIPlus(estimator=estimator,
                                        transformer=transformer,
                                        scoring_fns=self.scoring_fns,
                                        sample_split=self.sample_split,
                                        tree_random_state=tree_random_state,
                                        mode=self.mode,
                                        task=self.task,
                                        center=self.center,
                                        normalize=self.normalize)
            scores = tree_mdi_plus.get_scores(X, y)
            if scores is not None:
                all_scores.append(scores)
                all_full_preds.append(tree_mdi_plus._full_preds)
        if len(all_scores) == 0:
            raise ValueError("Transformer representation was empty for all trees.")
        full_preds = np.nanmean(all_full_preds, axis=0)
        self._full_preds = full_preds
        scoring_fns = self.scoring_fns if isinstance(self.scoring_fns, dict) \
            else {"importance": self.scoring_fns}
        for fn_name, scoring_fn in scoring_fns.items():
            self.feature_importances_by_tree_[fn_name] = pd.concat([scores[fn_name] for scores in all_scores], axis=1)
            self.feature_importances_by_tree_[fn_name].columns = np.arange(len(all_scores))
            self.feature_importances_[fn_name] = np.mean(self.feature_importances_by_tree_[fn_name], axis=1)
            self.prediction_score_[fn_name] = [scoring_fn(y[~np.isnan(full_preds)], full_preds[~np.isnan(full_preds)])]
        if list(scoring_fns.keys()) == ["importance"]:
            # Unwrap single-scorer outputs for convenience.
            self.prediction_score_ = self.prediction_score_["importance"]
            self.feature_importances_by_tree_ = self.feature_importances_by_tree_["importance"]
        if isinstance(X, pd.DataFrame):
            self.feature_importances_.index = X.columns
        self.feature_importances_.index.name = 'var'
        self.feature_importances_.reset_index(inplace=True)
        self.is_fitted = True

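# ---------------------------------------------------------------------------
# Usage sketch (commented out; illustrative only). This shows how ForestMDIPlus
# is typically driven. The two helpers marked "hypothetical" are assumptions:
# any BlockTransformerBase subclass from .block_transformers, and any routine
# that fits one partial prediction model on the transformed data of each tree,
# would fill these roles.
#
#     from sklearn.ensemble import RandomForestRegressor
#     from sklearn.metrics import r2_score
#
#     rf = RandomForestRegressor(n_estimators=100, min_samples_leaf=5).fit(X, y)
#     transformers = [make_block_transformer(tree) for tree in rf.estimators_]  # hypothetical
#     ppms = [fit_ppm(t, X, y) for t in transformers]                           # hypothetical
#     mdi_plus = ForestMDIPlus(ppms, transformers, scoring_fns=r2_score,
#                              sample_split="loo", task="regression")
#     importances = mdi_plus.get_scores(X, y)  # columns: 'var', 'importance'
# ---------------------------------------------------------------------------
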
class TreeMDIPlus:
    """
    The class object for computing MDI+ feature importances for a single tree.

    Generalized mean decrease in impurity (MDI+) is a flexible framework for computing RF
    feature importances. For more details, refer to [paper].

    Parameters
    ----------
    estimator: a fitted PartialPredictionModelBase object or scikit-learn type estimator
        The fitted partial prediction model to use for evaluating
        feature importance via MDI+. If not a PartialPredictionModelBase, then
        the estimator is coerced into a PartialPredictionModelBase object via
        GenericRegressorPPM or GenericClassifierPPM, depending on the specified
        task. Note that these generic PPMs may be computationally expensive.
    transformer: a BlockTransformerBase object
        A block feature transformer used to generate blocks of engineered
        features for each original feature. The transformed data is then used
        as input into the partial prediction models.
    scoring_fns: a function or dict with functions as values and function names (str) as keys
        The scoring functions used for evaluating the partial predictions.
    sample_split: string in {"loo", "oob", "inbag"} or None
        The sample splitting strategy to be used when evaluating the partial
        model predictions. The default "loo" (leave-one-out) is strongly
        recommended for performance and, in particular, for overcoming the known
        correlation and entropy biases suffered by MDI. "oob" (out-of-bag) can
        also be used to overcome these biases. "inbag" is the sample splitting
        strategy used by MDI. If None, no sample splitting is performed and the
        full data set is used to evaluate the partial model predictions.
    tree_random_state: int or None
        Random state of the fitted tree; used in sample splitting and
        only required if sample_split = "oob" or "inbag".
    mode: string in {"keep_k", "keep_rest"}
        Mode for the method. "keep_k" imputes the mean of each feature not
        in block k when making a partial model prediction, while "keep_rest"
        imputes the mean of each feature in block k. "keep_k" is strongly
        recommended for computational considerations.
    task: string in {"regression", "classification"}
        The supervised learning task for the RF model. Used for choosing
        defaults for the scoring_fns. Currently only regression and
        classification are supported.
    center: bool
        Flag for whether to center the transformed data in the transformers.
    normalize: bool
        Flag for whether to rescale the transformed data to have unit
        variance in the transformers.
    """

    def __init__(self, estimator, transformer, scoring_fns,
                 sample_split="loo", tree_random_state=None, mode="keep_k",
                 task="regression", center=True, normalize=False):
        assert sample_split in ["loo", "oob", "inbag", "auto", None]
        assert mode in ["keep_k", "keep_rest"]
        assert task in ["regression", "classification"]
        self.estimator = estimator
        self.transformer = transformer
        self.scoring_fns = scoring_fns
        self.sample_split = sample_split
        self.tree_random_state = tree_random_state
        _validate_sample_split(self.sample_split, self.estimator,
                               isinstance(self.estimator, PartialPredictionModelBase))
        if self.sample_split in ["oob", "inbag"] and not self.tree_random_state:
            raise ValueError("Must specify tree_random_state to use 'oob' or 'inbag' sample_split.")
        self.mode = mode
        self.task = task
        self.center = center
        self.normalize = normalize
        self.is_fitted = False
        self._full_preds = None
        self.prediction_score_ = None
        self.feature_importances_ = None

    def get_scores(self, X, y):
        """
        Obtain the MDI+ feature importances for a single tree.

        Parameters
        ----------
        X: ndarray of shape (n_samples, n_features)
            The covariate matrix. If a pd.DataFrame object is supplied, then
            the column names are used in the output.
        y: ndarray of shape (n_samples, n_targets)
            The observed responses.

        Returns
        -------
        scores: pd.DataFrame of shape (n_features, n_scoring_fns)
            The MDI+ feature importances.
        """
        self._fit_importance_scores(X, y)
        return self.feature_importances_

    def _fit_importance_scores(self, X, y):
        n_samples = y.shape[0]
        blocked_data = self.transformer.transform(X, center=self.center,
                                                  normalize=self.normalize)
        self.n_features = blocked_data.n_blocks
        train_blocked_data, test_blocked_data, y_train, y_test, test_indices = \
            _get_sample_split_data(blocked_data, y, self.tree_random_state, self.sample_split)
        # If the transformer produced no features, scores stay None and the
        # tree is skipped by ForestMDIPlus.
        if train_blocked_data.get_all_data().shape[1] != 0:
            if hasattr(self.estimator, "predict_full") and \
                    hasattr(self.estimator, "predict_partial"):
                full_preds = self.estimator.predict_full(test_blocked_data)
                partial_preds = self.estimator.predict_partial(test_blocked_data, mode=self.mode)
            else:
                # Coerce a generic (scikit-learn type) estimator into a PPM.
                if self.task == "regression":
                    ppm = GenericRegressorPPM(self.estimator)
                elif self.task == "classification":
                    ppm = GenericClassifierPPM(self.estimator)
                full_preds = ppm.predict_full(test_blocked_data)
                partial_preds = ppm.predict_partial(test_blocked_data, mode=self.mode)
            self._score_full_predictions(y_test, full_preds)
            self._score_partial_predictions(y_test, full_preds, partial_preds)
            # Scatter the test-sample predictions back into a full-length
            # array, leaving NaNs for samples outside the evaluation split.
            full_preds_n = np.empty(n_samples) if full_preds.ndim == 1 \
                else np.empty((n_samples, full_preds.shape[1]))
            full_preds_n[:] = np.nan
            full_preds_n[test_indices] = full_preds
            self._full_preds = full_preds_n
            self.is_fitted = True

    def _score_full_predictions(self, y_test, full_preds):
        scoring_fns = self.scoring_fns if isinstance(self.scoring_fns, dict) \
            else {"score": self.scoring_fns}
        all_prediction_scores = pd.DataFrame({})
        for fn_name, scoring_fn in scoring_fns.items():
            scores = scoring_fn(y_test, full_preds)
            all_prediction_scores[fn_name] = [scores]
        self.prediction_score_ = all_prediction_scores

    def _score_partial_predictions(self, y_test, full_preds, partial_preds):
        scoring_fns = self.scoring_fns if isinstance(self.scoring_fns, dict) \
            else {"importance": self.scoring_fns}
        all_scores = pd.DataFrame({})
        for fn_name, scoring_fn in scoring_fns.items():
            scores = _partial_preds_to_scores(partial_preds, y_test, scoring_fn)
            if self.mode == "keep_rest":
                full_score = scoring_fn(y_test, full_preds)
                scores = full_score - scores
            if len(partial_preds) != scores.size:
                if len(scoring_fns) > 1:
                    msg = "scoring_fn={} should return one value for each feature.".format(fn_name)
                else:
                    msg = "scoring_fns should return one value for each feature."
                raise ValueError("Unexpected dimensions. {}".format(msg))
            scores = scores.ravel()
            all_scores[fn_name] = scores
        self.feature_importances_ = all_scores

def _partial_preds_to_scores(partial_preds, y_test, scoring_fn):
    scores = []
    for k, y_pred in partial_preds.items():
        if isinstance(y_pred, tuple):  # if constant model
            y_pred = np.ones_like(y_test) * y_pred[1]
        scores.append(scoring_fn(y_test, y_pred))
    return np.vstack(scores)

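# Example of the constant-model convention handled above (hypothetical values):
# an entry partial_preds[k] == ("constant", 0.5) is expanded into the constant
# prediction vector np.ones_like(y_test) * 0.5 before scoring, so the function
# always returns one stacked row of scores per feature block.
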
def _get_default_sample_split(sample_split, prediction_model, is_ppm):
    if sample_split == "auto":
        sample_split = "oob"
        if is_ppm:
            if prediction_model.loo:
                sample_split = "loo"
    return sample_split

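# For instance, _get_default_sample_split("auto", ppm, is_ppm=True) resolves to
# "loo" when the partial prediction model was fit with loo=True and to "oob"
# otherwise; any non-"auto" value passes through unchanged.
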
def _validate_sample_split(sample_split, prediction_model, is_ppm):
    if sample_split in ["oob", "inbag"] and is_ppm:
        if prediction_model.loo:
            raise ValueError("Cannot use LOO together with OOB or in-bag sample splitting.")


def _get_sample_split_data(blocked_data, y, random_state, sample_split):
    if sample_split == "oob":
        train_blocked_data, test_blocked_data, y_train, y_test, _, test_indices = \
            _blocked_train_test_split(blocked_data, y, random_state)
    elif sample_split == "inbag":
        train_blocked_data, _, y_train, _, test_indices, _ = \
            _blocked_train_test_split(blocked_data, y, random_state)
        test_blocked_data = train_blocked_data
        y_test = y_train
    else:
        train_blocked_data = test_blocked_data = blocked_data
        y_train = y_test = y
        test_indices = np.arange(y.shape[0])
    return train_blocked_data, test_blocked_data, y_train, y_test, test_indices
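
if __name__ == "__main__":
    # Minimal runnable sketch (within the package context, given the relative
    # imports above) of the bootstrap stability computation performed in
    # ForestMDIPlus.get_stability_scores. scipy.stats.kendalltau stands in for
    # the tauAP_b / rbo metrics from .ranking_stability: an assumption made
    # purely so this demo has no dependency on the rest of the package.
    from scipy.stats import kendalltau

    rng = np.random.RandomState(0)
    n_features, n_trees, B = 5, 20, 10
    # Fake per-tree importances: one row per feature, one column per tree.
    feature_importances_by_tree = pd.DataFrame(rng.rand(n_features, n_trees))

    fi_scores_boot_ls = []
    for b in range(B):
        # Resample trees with replacement and average their importances.
        bootstrap_sample = rng.choice(n_trees, n_trees, replace=True)
        fi_scores_boot_ls.append(feature_importances_by_tree[bootstrap_sample].mean(axis=1))
    fi_scores_boot = pd.concat(fi_scores_boot_ls, axis=1)

    # Mean pairwise rank agreement across the B bootstrap replicates.
    stability = np.mean(pdist(fi_scores_boot.T, metric=lambda u, v: kendalltau(u, v)[0]))
    print("bootstrap rank stability (Kendall tau):", stability)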