Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

util.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
  1. from collections import Counter
  2. from typing import List
  3. from imodels.util.rule import Rule
  4. def extract_ensemble(weak_learners, X, y, min_multiplicity: int = 1) -> List[Rule]:
  5. all_rules = []
  6. all_subterms = []
  7. for est in weak_learners:
  8. est.fit(X, y)
  9. all_rules += est.rules_
  10. all_est_subterms = set([indv_r for r in est.rules_ for indv_r in split(r)])
  11. all_subterms += all_est_subterms
  12. if min_multiplicity > 0:
  13. # round rule decision boundaries to increase matching
  14. for i in range(len(all_rules)):
  15. for key in all_rules[i].agg_dict:
  16. all_rules[i].agg_dict[key] = round(float(all_rules[i].agg_dict[key]), 1)
  17. # match full_rules
  18. repeated_full_rules_counter = {k: v for k, v in Counter(all_rules).items() if v > min_multiplicity}
  19. repeated_rules = set(repeated_full_rules_counter.keys())
  20. # match subterms of rules
  21. repeated_subterm_counter = {k: v for k, v in Counter(all_subterms).items() if v > min_multiplicity}
  22. repeated_rules = repeated_rules.union(set(repeated_subterm_counter.keys()))
  23. # convert to str form to be rescored
  24. repeated_rules = list(map(str, repeated_rules))
  25. return repeated_rules
  26. def split(rule: Rule) -> List[Rule]:
  27. if len(rule.agg_dict) == 1:
  28. return [rule]
  29. else:
  30. indv_rule_strs = list(map(lambda x: ' '.join(x), rule.terms))
  31. indv_rules = list(map(lambda x: Rule(x), indv_rule_strs))
  32. return indv_rules
Tip!

Press p to see the previous file or n to see the next file

Comments

Loading...