greedy_rule_list.py

'''Greedy rule list.
Greedily splits on one feature at a time along a single path.
Tries to find rules which maximize the probability of class 1.
Currently only supports binary classification.
'''
import math
from copy import deepcopy

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_array, check_is_fitted
from sklearn.tree import DecisionTreeClassifier

from imodels.rule_list.rule_list import RuleList
from imodels.util.arguments import check_fit_arguments


class GreedyRuleListClassifier(BaseEstimator, RuleList, ClassifierMixin):
    def __init__(self, max_depth: int = 5, class_weight=None,
                 criterion: str = 'gini'):
        '''
        Params
        ------
        max_depth
            Maximum depth the list can achieve
        criterion: str
            Criterion used to split
            'gini', 'entropy', or 'log_loss'
        '''
        self.max_depth = max_depth
        self.class_weight = class_weight
        self.criterion = criterion
        self.depth = 0  # tracks the fitted depth
    def fit(self, X, y, depth: int = 0, feature_names=None, verbose=False):
        """
        Params
        ------
        X: array_like
            Feature set
        y: array_like
            target variable
        depth
            the depth of the current layer (used to recurse)
        """
        X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
        return self.fit_node_recursive(X, y, depth=0, verbose=verbose)
    def fit_node_recursive(self, X, y, depth: int, verbose):

        # base case 1: no data in this group
        if y.size == 0:
            return []

        # base case 2: all y is the same in this group
        elif np.all(y == y[0]):
            return [{'val': y[0], 'num_pts': y.size}]

        # base case 3: max depth reached
        elif depth == self.max_depth:
            return [{'val': np.mean(y), 'num_pts': y.size}]

        # recursively generate rule list
        else:

            # find a split with the best value for the criterion
            m = DecisionTreeClassifier(max_depth=1, criterion=self.criterion)
            m.fit(X, y)
            col = m.tree_.feature[0]
            cutoff = m.tree_.threshold[0]
            # col, cutoff, criterion_val = self._find_best_split(X, y)

            # col == -2 means the stump could not find a valid split
            if col == -2:
                return []

            y_left = y[X[:, col] < cutoff]  # left-hand side data
            y_right = y[X[:, col] >= cutoff]  # right-hand side data

            # put higher probability of class 1 on the right-hand side
            if len(y_left) > 0 and np.mean(y_left) > np.mean(y_right):
                flip = True
                tmp = deepcopy(y_left)
                y_left = deepcopy(y_right)
                y_right = tmp
                x_left = X[X[:, col] >= cutoff]
            else:
                flip = False
                x_left = X[X[:, col] < cutoff]

            # print
            if verbose:
                print(
                    f'{np.mean(100 * y):.2f} -> {self.feature_names_[col]} -> '
                    f'{np.mean(100 * y_left):.2f} ({y_left.size}) {np.mean(100 * y_right):.2f} ({y_right.size})')

            # save info
            par_node = [{
                'col': self.feature_names_[col],
                'index_col': col,
                'cutoff': cutoff,
                'val': np.mean(y_left),  # will be the value before splitting in the next lower level
                'flip': flip,
                'val_right': np.mean(y_right),
                'num_pts': y.size,
                'num_pts_right': y_right.size
            }]

            # generate the rule list for the non-leaf data
            par_node = par_node + \
                self.fit_node_recursive(x_left, y_left, depth + 1, verbose=verbose)

            self.depth += 1  # increase the depth since we call fit once
            self.rules_ = par_node
            self.complexity_ = len(self.rules_)
            self.classes_ = unique_labels(y)
            return par_node
    def predict_proba(self, X):
        check_is_fitted(self)
        X = check_array(X)
        n = X.shape[0]
        probs = np.zeros(n)
        for i in range(n):
            x = X[i]
            for j, rule in enumerate(self.rules_):
                # the final rule acts as the default: use its value if no earlier rule fired
                if j == len(self.rules_) - 1:
                    probs[i] = rule['val']
                    continue
                regular_condition = x[rule["index_col"]] >= rule["cutoff"]
                flipped_condition = x[rule["index_col"]] < rule["cutoff"]
                condition = flipped_condition if rule["flip"] else regular_condition
                if condition:
                    probs[i] = rule['val_right']
                    break
        return np.vstack((1 - probs, probs)).transpose()  # probs (n, 2)

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        return np.argmax(self.predict_proba(X), axis=1)
  121. """
  122. def __str__(self):
  123. # s = ''
  124. # for rule in self.rules_:
  125. # s += f"mean {rule['val'].round(3)} ({rule['num_pts']} pts)\n"
  126. # if 'col' in rule:
  127. # s += f"if {rule['col']} >= {rule['cutoff']} then {rule['val_right'].round(3)} ({rule['num_pts_right']} pts)\n"
  128. # return s
  129. """
    def __str__(self):
        '''Print out the list in a nice way
        '''
        s = '> ------------------------------\n> Greedy Rule List\n> ------------------------------\n'

        def red(s):
            # return f"\033[91m{s}\033[00m"
            return s

        def cyan(s):
            # return f"\033[96m{s}\033[00m"
            return s

        def rule_name(rule):
            if rule['flip']:
                return '~' + rule['col']
            return rule['col']

        # rule = self.rules_[0]
        # s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
        for rule in self.rules_:
            s += u'\u2193\n' + f"{cyan((100 * rule['val']).round(2))}% risk ({rule['num_pts']} pts)\n"
            # s += f"\t{'Else':>45} => {cyan((100 * rule['val']).round(2)):>6}% IwI ({rule['val'] * rule['num_pts']:.0f}/{rule['num_pts']} pts)\n"
            if 'col' in rule:
                # prefix = f"if {rule['col']} >= {rule['cutoff']}"
                prefix = f"if {rule_name(rule)}"
                val = f"{100 * rule['val_right'].round(3)}"
                s += f"\t{prefix} ==> {red(val)}% risk ({rule['num_pts_right']} pts)\n"
        # rule = self.rules_[-1]
        # s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
        return s
    ######## HERE ONWARDS CUSTOM SPLITTING (DEPRECATED IN FAVOR OF SKLEARN STUMP) ########
    ######################################################################################

    def _find_best_split(self, x, y):
        """
        Find the best split from all features
        returns: the column to split on, the cutoff value, and the actual criterion_value
        """
        col = None
        min_criterion_val = 1e10
        cutoff = None

        # iterate through each feature
        for i, c in enumerate(x.T):
            # find the best split of that feature
            criterion_val, cur_cutoff = self._split_on_feature(c, y)

            # found a perfect cutoff
            if criterion_val == 0:
                return i, cur_cutoff, criterion_val

            # check if it's the best so far
            elif criterion_val <= min_criterion_val:
                min_criterion_val = criterion_val
                col = i
                cutoff = cur_cutoff
        return col, cutoff, min_criterion_val
    def _split_on_feature(self, col, y):
        """
        col: the column we split on
        y: target var
        """
        min_criterion_val = 1e10
        cutoff = 0.5

        # iterate through each value in the column
        for value in np.unique(col):
            # separate y into 2 groups
            y_predict = col < value

            # get criterion val of this split
            criterion_val = self._weighted_criterion(y_predict, y)

            # check if it's the smallest one so far
            if criterion_val <= min_criterion_val:
                min_criterion_val = criterion_val
                cutoff = value
        return min_criterion_val, cutoff
    def _weighted_criterion(self, split_decision, y_real):
        """Returns the criterion calculated over a split
        split_decision is a boolean array; y_real can be multi-class
        """
        if split_decision.shape[0] != y_real.shape[0]:
            print('They have to be the same length')
            return None

        # choose the splitting criterion
        if self.criterion == 'entropy':
            criterion_func = self._entropy_criterion
        elif self.criterion == 'gini':
            criterion_func = self._gini_criterion
        elif self.criterion == 'neg_corr':
            return self._neg_corr_criterion(split_decision, y_real)

        # left-hand side criterion
        s_left = criterion_func(y_real[split_decision])

        # right-hand side criterion
        s_right = criterion_func(y_real[~split_decision])

        # overall criterion is a weighted average of both sides
        n = y_real.shape[0]
        if self.class_weight is not None:
            sample_weights = np.ones(n)
            for c in self.class_weight.keys():
                idxs_c = y_real == c
                sample_weights[idxs_c] = self.class_weight[c]
            total_weight = np.sum(sample_weights)
            weight_left = np.sum(sample_weights[split_decision]) / total_weight
            # weight_right = np.sum(sample_weights[~split_decision]) / total_weight
        else:
            tot_left_samples = np.sum(split_decision == 1)
            weight_left = tot_left_samples / n
        s = weight_left * s_left + (1 - weight_left) * s_right
        return s
    def _gini_criterion(self, y):
        '''Returns the Gini index for one node
        = sum(p_c * (1 - p_c)) over classes c
        '''
        s = 0
        n = y.shape[0]
        classes = np.unique(y)

        # sum the contribution of each class
        for c in classes:
            # fraction of points in this class
            n_c = np.sum(y == c)
            p_c = n_c / n
            s += p_c * (1 - p_c)
        return s
    def _entropy_criterion(self, y):
        """Returns entropy of a divided group of data
        Data may have multiple classes
        """
        s = 0
        n = len(y)
        classes = set(y)

        # for each class, add its weighted entropy contribution
        for c in classes:
            # weight for this class
            weight = sum(y == c) / n

            def _entropy_from_counts(c1, c2):
                """Returns entropy of a group of data
                c1: count of one class
                c2: count of another class
                """
                if c1 == 0 or c2 == 0:  # when there is only one class in the group, entropy is 0
                    return 0

                def _entropy_func(p):
                    return -p * math.log(p, 2)

                p1 = c1 * 1.0 / (c1 + c2)
                p2 = c2 * 1.0 / (c1 + c2)
                return _entropy_func(p1) + _entropy_func(p2)

            # weighted avg
            s += weight * _entropy_from_counts(sum(y == c), sum(y != c))
        return s
    def _neg_corr_criterion(self, split_decision, y):
        '''Returns negative correlation between y
        and the binary splitting variable split_decision
        y must be binary
        '''
        if np.unique(y).size < 2:
            return 0
        elif np.unique(y).size != 2:
            print('y must be binary output for corr criterion')

        # y should be 1 more often on the "right side" of the split
        if y.sum() < y.size / 2:
            y = 1 - y

        # cast the boolean split indicator to int before computing the correlation
        return -1 * np.corrcoef(split_decision.astype(int), y)[0, 1]
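

# A minimal usage sketch, not part of the original module: it assumes this file is
# importable as-is and uses the standard sklearn breast-cancer dataset purely for
# illustration of fit / predict / __str__ on a binary classification task.
if __name__ == '__main__':
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split

    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    model = GreedyRuleListClassifier(max_depth=3)
    model.fit(X_train, y_train)
    print(model)  # prints the fitted rule list via __str__
    print('test accuracy:', np.mean(model.predict(X_test) == y_test))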