Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

stableskope.py 2.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  from typing import List, Tuple

  import numpy as np
  from sklearn.base import BaseEstimator

  from imodels.rule_set.skope_rules import SkopeRulesClassifier
  from imodels.util.rule import Rule
  from imodels.util.score import score_precision_recall

  from .util import extract_ensemble
  class StableSkopeClassifier(SkopeRulesClassifier):
      """``SkopeRulesClassifier`` whose candidate rules come from user-supplied weak learners.

      ``_extract_rules`` is overridden so that the candidate rules are produced by
      ``extract_ensemble(weak_learners, X, y, min_mult)``, and ``_score_rules`` scores
      them with ``score_precision_recall`` on the full training data (``oob=False``).
      All other behaviour is inherited from ``SkopeRulesClassifier``.

      Parameters
      ----------
      weak_learners : list of estimators
          Passed to ``extract_ensemble`` to obtain the candidate rules.
      max_complexity : int
          Stored on the instance. NOTE(review): it is not read anywhere in this
          class; confirm whether the parent or a caller is expected to use it.
      min_mult : int, default 1
          Forwarded to ``extract_ensemble`` as its fourth argument; its meaning is
          defined by that helper.
      precision_min, recall_min, n_estimators, max_samples, max_samples_features,
      bootstrap, bootstrap_features, max_depth, max_depth_duplication, max_features,
      min_samples_split, n_jobs, random_state
          Forwarded unchanged to ``SkopeRulesClassifier.__init__``.
      """

      def __init__(self,
                   weak_learners: List[BaseEstimator],
                   max_complexity: int,
                   min_mult: int = 1,
                   precision_min=0.5,
                   recall_min=0.4,
                   n_estimators=10,
                   max_samples=.8,
                   max_samples_features=.8,
                   bootstrap=False,
                   bootstrap_features=False,
                   max_depth=3,
                   max_depth_duplication=None,
                   max_features=1.,
                   min_samples_split=2,
                   n_jobs=1,
                   random_state=None):
          # Forward by keyword rather than position: if the parent's signature is
          # ever reordered or extended, positional forwarding would silently
          # assign values to the wrong hyper-parameters.
          super().__init__(precision_min=precision_min,
                           recall_min=recall_min,
                           n_estimators=n_estimators,
                           max_samples=max_samples,
                           max_samples_features=max_samples_features,
                           bootstrap=bootstrap,
                           bootstrap_features=bootstrap_features,
                           max_depth=max_depth,
                           max_depth_duplication=max_depth_duplication,
                           max_features=max_features,
                           min_samples_split=min_samples_split,
                           n_jobs=n_jobs,
                           random_state=random_state)
          self.weak_learners = weak_learners
          self.max_complexity = max_complexity  # NOTE(review): unused in this class
          self.min_mult = min_mult

      def fit(self, X, y=None, feature_names=None, sample_weight=None):
          """Fit using the parent implementation and return ``self``.

          This is a pass-through to ``SkopeRulesClassifier.fit``. Rule extraction
          and scoring are customised through the ``_extract_rules`` and
          ``_score_rules`` overrides below.
          """
          super().fit(X, y, feature_names=feature_names, sample_weight=sample_weight)
          return self

      def _extract_rules(self, X, y) -> Tuple[List[List[str]], List[np.ndarray], List[np.ndarray]]:
          """Return the single candidate rule set built from ``weak_learners``.

          Returns
          -------
          rules : list
              One-element list holding whatever ``extract_ensemble`` returned
              (presumably a list of rule strings — verify against ``.util``).
          samples : list of ndarray
              One-element list with the indices of every row of ``X``.
          features : list of ndarray
              One-element list with the indices of every entry of
              ``self.feature_names`` (presumably set by the parent's ``fit``).
          """
          rules = extract_ensemble(self.weak_learners, X, y, self.min_mult)
          # The rules are extracted once from all of the data, so they are paired
          # with every sample index and every feature index (no bagged subsets).
          all_samples = np.arange(X.shape[0])
          all_features = np.arange(len(self.feature_names))
          return [rules], [all_samples], [all_features]

      def _score_rules(self, X, y, rules) -> List[Rule]:
          """Score ``rules`` with ``score_precision_recall``.

          ``oob=False`` is passed. The sample indices recorded by
          ``_extract_rules`` cover every row, so no out-of-bag rows would exist.
          ``estimators_samples_``, ``estimators_features_`` and
          ``feature_placeholders`` are read from the instance; they are
          presumably populated by the parent's ``fit``.
          """
          return score_precision_recall(X, y,
                                        rules,
                                        self.estimators_samples_,
                                        self.estimators_features_,
                                        self.feature_placeholders,
                                        oob=False)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...