Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

boosted_rules.py 3.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. from copy import deepcopy
  2. from functools import partial
  3. import numpy as np
  4. import sklearn
  5. from sklearn.ensemble import AdaBoostClassifier, AdaBoostRegressor
  6. from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin
  7. from sklearn.model_selection import train_test_split
  8. from sklearn.preprocessing import normalize
  9. from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
  10. from sklearn.utils.multiclass import check_classification_targets, unique_labels
  11. from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
  12. from imodels.rule_set.rule_set import RuleSet
  13. from imodels.rule_set.slipper_util import SlipperBaseEstimator
  14. from imodels.util.arguments import check_fit_arguments
  15. from imodels.util.convert import tree_to_code, tree_to_rules, dict_to_rule
  16. from imodels.util.rule import Rule, get_feature_dict, replace_feature_name
  17. class BoostedRulesClassifier(AdaBoostClassifier):
  18. '''An easy-interpretable classifier optimizing simple logical rules.
  19. Params
  20. ------
  21. estimator: object with fit and predict methods
  22. Defaults to DecisionTreeClassifier with AdaBoost.
  23. For SLIPPER, should pass estimator=imodels.SlipperBaseEstimator
  24. '''
  25. def __init__(
  26. self,
  27. estimator=DecisionTreeClassifier(max_depth=1),
  28. *,
  29. n_estimators=15,
  30. learning_rate=1.0,
  31. random_state=None,
  32. ):
  33. try: # sklearn version >= 1.2
  34. super().__init__(
  35. estimator=estimator,
  36. n_estimators=n_estimators,
  37. learning_rate=learning_rate,
  38. random_state=random_state,
  39. )
  40. except: # sklearn version < 1.2
  41. super().__init__(
  42. base_estimator=estimator,
  43. n_estimators=n_estimators,
  44. learning_rate=learning_rate,
  45. random_state=random_state,
  46. )
  47. self.estimator = estimator
  48. def fit(self, X, y, feature_names=None, **kwargs):
  49. X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
  50. super().fit(X, y, **kwargs)
  51. self.complexity_ = len(self.estimators_)
  52. class BoostedRulesRegressor(AdaBoostRegressor):
  53. '''An easy-interpretable regressor optimizing simple logical rules.
  54. Params
  55. ------
  56. estimator: object with fit and predict methods
  57. Defaults to DecisionTreeRegressor with AdaBoost.
  58. '''
  59. def __init__(
  60. self,
  61. estimator=DecisionTreeRegressor(max_depth=1),
  62. *,
  63. n_estimators=15,
  64. learning_rate=1.0,
  65. random_state=13,
  66. ):
  67. try: # sklearn version >= 1.2
  68. super().__init__(
  69. estimator=estimator,
  70. n_estimators=n_estimators,
  71. learning_rate=learning_rate,
  72. random_state=random_state,
  73. )
  74. except: # sklearn version < 1.2
  75. super().__init__(
  76. base_estimator=estimator,
  77. n_estimators=n_estimators,
  78. learning_rate=learning_rate,
  79. random_state=random_state,
  80. )
  81. self.estimator = estimator
  82. def fit(self, X, y, feature_names=None, **kwargs):
  83. X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
  84. super().fit(X, y, **kwargs)
  85. self.complexity_ = len(self.estimators_)
  86. if __name__ == '__main__':
  87. np.random.seed(13)
  88. X, Y = sklearn.datasets.load_breast_cancer(as_frame=True, return_X_y=True)
  89. model = BoostedRulesClassifier(estimator=DecisionTreeClassifier)
  90. X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size = 0.3)
  91. model.fit(X_train, y_train, feature_names=X_train.columns)
  92. y_pred = model.predict(X_test)
  93. acc = model.score(X_test, y_test)
  94. print('acc', acc, 'complexity', model.complexity_)
  95. print(model)
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...