corels_wrapper.py

# This is just a simple wrapper around pycorels: https://github.com/corels/pycorels
import warnings
from typing import List

import numpy as np
import pandas as pd
from sklearn.preprocessing import KBinsDiscretizer

from imodels.rule_list.greedy_rule_list import GreedyRuleListClassifier

# corels is an optional dependency; fall back to a greedy rule list when it is missing
corels_supported = False
try:
    from corels import CorelsClassifier
    corels_supported = True
except ImportError:
    pass


class OptimalRuleListClassifier(GreedyRuleListClassifier if not corels_supported else CorelsClassifier):
  15. """Certifiably Optimal RulE ListS classifier.
  16. This class implements the CORELS algorithm, designed to produce human-interpretable, optimal
  17. rulelists for binary feature data and binary classification. As an alternative to other
  18. tree based algorithms such as CART, CORELS provides a certificate of optimality for its
  19. rulelist given a training set, leveraging multiple algorithmic bounds to do so.
  20. In order to use run the algorithm, create an instance of the `CorelsClassifier` class,
  21. providing any necessary parameters in its constructor, and then call `fit` to generate
  22. a rulelist. `printrl` prints the generated rulelist, while `predict` provides
  23. classification predictions for a separate test dataset with the same features. To determine
  24. the algorithm's accuracy, run `score` on an evaluation dataset with labels.
  25. To save a generated rulelist to a file, call `save`. To load it back from the file, call `load`.
  26. Attributes
  27. ----------
  28. c : float, optional (default=0.01)
  29. Regularization parameter. Higher values penalize longer rulelists.
  30. n_iter : int, optional (default=10000)
  31. Maximum number of nodes (rulelists) to search before exiting.
  32. map_type : str, optional (default="prefix")
  33. The type of prefix map to use. Supported maps are "none" for no map,
  34. "prefix" for a map that uses rule prefixes for keys, "captured" for
  35. a map with a prefix's captured vector as keys.
  36. policy : str, optional (default="lower_bound")
  37. The search policy for traversing the tree (i.e. the criterion with which
  38. to order nodes in the queue). Supported criteria are "bfs", for breadth-first
  39. search; "curious", which attempts to find the most promising node;
  40. "lower_bound" which is the objective function evaluated with that rulelist
  41. minus the default prediction error; "objective" for the objective function
  42. evaluated at that rulelist; and "dfs" for depth-first search.
  43. verbosity : list, optional (default=["rulelist"])
  44. The verbosity levels required. A list of strings, it can contain any
  45. subset of ["rulelist", "rule", "label", "minor", "samples", "progress", "mine", "loud"].
  46. An empty list ([]) indicates 'silent' mode.
  47. - "rulelist" prints the generated rulelist at the end.
  48. - "rule" prints a summary of each rule generated.
  49. - "label" prints a summary of the class labels.
  50. - "minor" prints a summary of the minority bound.
  51. - "samples" produces a complete dump of the rules, label, and/or minor data. You must also provide at least one of "rule", "label", or "minor" to specify which data you want to dump, or "loud" for all data. The "samples" option often spits out a lot of output.
  52. - "progress" prints periodic messages as corels runs.
  53. - "mine" prints debug information while mining rules, including each rule as it is generated.
  54. - "loud" is the equivalent of ["progress", "label", "rule", "mine", "minor"].
  55. ablation : int, optional (default=0)
  56. Specifies addition parameters for the bounds used while searching. Accepted
  57. values are 0 (all bounds), 1 (no antecedent support bound), and 2 (no
  58. lookahead bound).
  59. max_card : int, optional (default=2)
  60. Maximum cardinality allowed when mining rules. Can be any value greater than
  61. or equal to 1. For instance, a value of 2 would only allow rules that combine
  62. at most two features in their antecedents.
  63. min_support : float, optional (default=0.01)
  64. The fraction of samples that a rule must capture in order to be used. 1 minus
  65. this value is also the maximum fraction of samples a rule can capture.
  66. Can be any value between 0.0 and 0.5.
  67. References
  68. ----------
  69. Elaine Angelino, Nicholas Larus-Stone, Daniel Alabi, Margo Seltzer, and Cynthia Rudin.
  70. Learning Certifiably Optimal Rule Lists for Categorical Data. KDD 2017.
  71. Journal of Machine Learning Research, 2018; 19: 1-77. arXiv:1704.01701, 2017
  72. Examples
  73. --------
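    A minimal usage sketch (hypothetical binary data and illustrative feature
    names; the optimal branch assumes the optional `corels` package is
    installed, otherwise the greedy fallback is fit):

    >>> import numpy as np
    >>> X = np.random.randint(2, size=(100, 3))  # binary feature matrix
    >>> y = X[:, 0]                              # binary labels
    >>> m = OptimalRuleListClassifier(c=0.01, max_card=2)
    >>> _ = m.fit(X, y, feature_names=['f0', 'f1', 'f2'])
    >>> preds = m.predict(X)                     # hard 0/1 predictions
    >>> probs = m.predict_proba(X)               # shape [100, 2]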
  74. """
    def __init__(self, c=0.01, n_iter=10000, map_type="prefix", policy="lower_bound",
                 verbosity=[], ablation=0, max_card=2, min_support=0.01, random_state=0):
        if corels_supported:
            super().__init__(c, n_iter, map_type, policy, verbosity, ablation, max_card, min_support)
        else:
            warnings.warn("corels is not installed; install it with `pip install corels`. "
                          "Falling back to GreedyRuleListClassifier.")
            super().__init__()
            # delegate directly to the greedy fallback's implementations
            self.fit = super().fit
            self.predict = super().predict
            self.predict_proba = super().predict_proba
            self.__str__ = super().__str__

        self.random_state = random_state
        self.discretizer = None
        self.str_print = None
        self._estimator_type = 'classifier'
    def fit(self, X, y, feature_names=None, prediction_name="prediction"):
        """
        Build a CORELS classifier from the training set (X, y).

        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]
            The training input samples. All features must be binary, and the matrix
            is internally converted to dtype=np.uint8.
        y : array-like, shape = [n_samples]
            The target values for the training input. Must be binary.
        feature_names : list, optional (default=None)
            A list of strings of length n_features. Specifies the names of each
            of the features. If an empty list is provided, the feature names
            are set to the default of ["feature1", "feature2"... ].
        prediction_name : string, optional (default="prediction")
            The name of the feature that is being predicted.

        Returns
        -------
        self : obj
        """
        if isinstance(X, pd.DataFrame):
            if feature_names is None:
                feature_names = X.columns.tolist()
            X = X.values
        elif feature_names is None:
            feature_names = ['X_' + str(i) for i in range(X.shape[1])]

        # if any feature is non-binary, one-hot discretize so the rule miner can handle it
        if not np.isin(X, [0, 1]).all().all():
            self.discretizer = KBinsDiscretizer(encode='onehot-dense')
            self.discretizer.fit(X, y)
            # alternative naming scheme, kept for reference:
            # feature_names = [f'{col}_{b}'
            #                  for col, bins in zip(feature_names, self.discretizer.n_bins_)
            #                  for b in range(bins)]
            feature_names = self.discretizer.get_feature_names_out()
            X = self.discretizer.transform(X)

        np.random.seed(self.random_state)
        # feature_names = feature_names.tolist()
        super().fit(X, y, features=feature_names, prediction_name=prediction_name)
        # try:
        self._traverse_rule(X, y, feature_names)
        # except:
        #     self.str_print = None
        self.complexity_ = self._get_complexity()
        return self
    def predict(self, X):
        """
        Predict classifications of the input samples X.

        Arguments
        ---------
        X : array-like, shape = [n_samples, n_features]
            The input samples. All features must be binary, and the matrix
            is internally converted to dtype=np.uint8. The features must be the same
            as those of the data used to train the model.

        Returns
        -------
        p : array[int] of shape = [n_samples].
            The classifications of the input samples.
        """
        if self.discretizer is not None:
            X = self.discretizer.transform(X)
        return super().predict(X).astype(int)
    def predict_proba(self, X):
        """
        Predict probabilities of the input samples X.
        todo: actually calculate these from training set

        Arguments
        ---------
        X : array-like, shape = [n_samples, n_features]
            The input samples. All features must be binary, and the matrix
            is internally converted to dtype=np.uint8. The features must be the same
            as those of the data used to train the model.

        Returns
        -------
        p : array[float] of shape = [n_samples, 2].
            The probabilities of the input samples.
        """
        preds = self.predict(X)
        # degenerate probabilities: column 1 is the hard prediction, column 0 its complement
        return np.vstack((1 - preds, preds)).transpose()
    def _traverse_rule(self, X: np.ndarray, y: np.ndarray, feature_names: List[str], print_colors=False):
        """Traverse the learned rule list and build up its string representation.

        Parameters
        ----------
        X : np.ndarray
            Binary feature matrix the model was fit on.
        y : np.ndarray
            Binary target values.
        feature_names : List[str]
            Names of the columns of X.
        print_colors : bool
            Whether to wrap each rule in ANSI color codes.
        """
        str_print = ''
        df = pd.DataFrame(X, columns=feature_names)
        df.loc[:, 'y'] = y
        o = 'y'
        str_print += f' {df[o].sum()} / {df.shape[0]} (positive class / total)\n'
        if print_colors:
            color_start = '\033[96m'
            color_end = '\033[00m'
        else:
            color_start = ''
            color_end = ''
        if len(self.rl_.rules) > 1:
            str_print += f'\t\u2193 \n'
        else:
            str_print += ' No rules learned\n'
        for j, rule in enumerate(self.rl_.rules[:-1]):
            antecedents = rule['antecedents']
            query = ''
            for i, feat_idx in enumerate(antecedents):
                if i > 0:
                    query += ' & '
                # antecedent indices are 1-based; a negative index encodes a negated feature
                if feat_idx < 0:
                    query += f'(`{feature_names[-feat_idx - 1]}` == 0)'
                else:
                    query += f'(`{feature_names[feat_idx - 1]}` == 1)'
            df_rhs = df.query(query)
            idxs_satisfying_rule = df_rhs.index
            df.drop(index=idxs_satisfying_rule, inplace=True)
            computed_prob = 100 * df_rhs[o].sum() / (df_rhs.shape[0] + 1e-10)

            # add to str_print
            query_print = query.replace('== 1', '').replace('(', '').replace(')', '').replace('`', '')
            str_print += f'{color_start}If {query_print:<35}{color_end} \u2192 {df_rhs[o].sum():>3} / {df_rhs.shape[0]:>4} ({computed_prob:0.1f}%)\n\t\u2193 \n {df[o].sum():>3} / {df.shape[0]:>5}\t \n'
            if not (j == len(self.rl_.rules) - 2 and i == len(antecedents) - 1):
                str_print += '\t\u2193 \n'
        self.str_print = str_print
    def __str__(self):
        if corels_supported:
            if self.str_print is not None:
                return 'OptimalRuleList:\n\n' + self.str_print
            else:
                return 'OptimalRuleList:\n\n' + self.rl_.__str__()
        else:
            return super().__str__()

    def _get_complexity(self):
        return sum(len(corule['antecedents']) for corule in self.rl_.rules)
if __name__ == '__main__':
    X = (np.random.randn(40, 2) > 0).astype(int)
    y = (X[:, 0] > 0).astype(int)
    y[-2:] = 1 - y[-2:]  # flip the last two labels so the data is not perfectly separable
    m = OptimalRuleListClassifier()
    m.fit(X, y)
    print(str(m))
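
    # A minimal sketch of downstream use (illustrative; works for both the corels
    # path and the greedy fallback, since both expose predict / predict_proba):
    preds = m.predict(X)        # hard 0/1 class predictions
    probs = m.predict_proba(X)  # [n_samples, 2] probability matrix
    print('train accuracy:', np.mean(np.asarray(preds) == y))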