Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

stablelinear.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
  1. import numpy as np
  2. from typing import List
  3. from imodels.rule_set.rule_fit import RuleFit
  4. from imodels.util.score import score_linear
  5. from sklearn.base import ClassifierMixin, RegressorMixin, BaseEstimator
  6. from .util import extract_ensemble
class StableLinear(RuleFit):
    """RuleFit variant whose candidate rules come from a pre-fit ensemble of
    weak learners instead of being generated internally.

    Overrides ``_extract_rules`` to pull rules out of ``weak_learners`` via
    ``extract_ensemble``, and ``_score_rules`` to fit a penalized linear model
    (L1/L2, selected by ``penalty``) over the rule features. All remaining
    behavior is inherited from :class:`RuleFit`.
    """

    def __init__(self,
                 weak_learners: List[BaseEstimator],
                 max_complexity: int,
                 min_mult: int = 1,
                 penalty='l1',
                 n_estimators=100,
                 tree_size=4,
                 sample_fract='default',
                 max_rules=30,
                 memory_par=0.01,
                 tree_generator=None,
                 lin_trim_quantile=0.025,
                 lin_standardise=True,
                 exp_rand_tree_size=True,
                 include_linear=False,
                 alpha=None,
                 random_state=None):
        """
        Parameters
        ----------
        weak_learners : List[BaseEstimator]
            Fitted estimators whose internal rules are extracted and re-scored.
        max_complexity : int
            Stored on the instance; presumably a cap on rule complexity
            consumed elsewhere — not referenced in this class body.
        min_mult : int
            Minimum multiplicity threshold forwarded to ``extract_ensemble``
            (NOTE(review): likely "keep rules appearing in at least this many
            learners" — confirm against ``extract_ensemble``).
        penalty : str
            Regularization type passed through to ``score_linear``
            (``'l1'`` by default).

        The remaining parameters are forwarded positionally to
        ``RuleFit.__init__`` unchanged.
        """
        # Forward the shared RuleFit configuration positionally; argument
        # order must match RuleFit.__init__'s signature.
        super().__init__(n_estimators,
                         tree_size,
                         sample_fract,
                         max_rules,
                         memory_par,
                         tree_generator,
                         lin_trim_quantile,
                         lin_standardise,
                         exp_rand_tree_size,
                         include_linear,
                         alpha,
                         random_state)
        # Parameters specific to the "stable" variant.
        self.max_complexity = max_complexity
        self.weak_learners = weak_learners
        self.penalty = penalty
        self.min_mult = min_mult

    def fit(self, X, y=None, feature_names=None):
        """Fit via RuleFit's pipeline (which calls the overrides below);
        returns ``self`` for sklearn-style chaining."""
        super().fit(X, y, feature_names=feature_names)
        return self

    def _extract_rules(self, X, y) -> List[str]:
        # Candidate rules come from the externally supplied ensemble rather
        # than trees grown by this estimator.
        return extract_ensemble(self.weak_learners, X, y, self.min_mult)

    def _score_rules(self, X, y, rules):
        """Build the [linear features | rule features] design matrix and fit
        a penalized linear model over it.

        Side effects: when ``include_linear`` is set, trains
        ``self.winsorizer`` (and possibly ``self.friedscale``) and stores
        ``self.stddev`` / ``self.mean`` of the winsorized features.

        Returns whatever ``score_linear`` returns, or ``([], [], 0)`` when
        the design matrix has no columns (no rules and no linear terms).
        """
        # Start with an (n_samples, 0) matrix and concatenate feature groups.
        X_concat = np.zeros([X.shape[0], 0])
        # standardise linear variables if requested (for regression model only)
        if self.include_linear:
            # standard deviation and mean of winsorized features
            self.winsorizer.train(X)
            winsorized_X = self.winsorizer.trim(X)
            self.stddev = np.std(winsorized_X, axis=0)
            self.mean = np.mean(winsorized_X, axis=0)
            if self.lin_standardise:
                # Friedman-style scaling of the raw linear terms.
                self.friedscale.train(X)
                X_regn = self.friedscale.scale(X)
            else:
                X_regn = X.copy()
            X_concat = np.concatenate((X_concat, X_regn), axis=1)
        # Binary rule-activation features from the inherited transform.
        X_rules = self.transform(X, rules)
        if X_rules.shape[0] > 0:
            X_concat = np.concatenate((X_concat, X_rules), axis=1)
        # no rules fit and self.include_linear == False
        if X_concat.shape[1] == 0:
            return [], [], 0
        return score_linear(X_concat, y, rules,
                            alpha=self.alpha,
                            penalty=self.penalty,
                            prediction_task=self.prediction_task,
                            max_rules=self.max_rules, random_state=self.random_state)
  72. class StableLinearRegressor(StableLinear, RegressorMixin):
  73. def _init_prediction_task(self):
  74. self.prediction_task = 'regression'
  75. class StableLinearClassifier(StableLinear, ClassifierMixin):
  76. def _init_prediction_task(self):
  77. self.prediction_task = 'classification'
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...