Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

regression_test.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
  1. from functools import partial
  2. import numpy as np
  3. import pytest
  4. from sklearn.ensemble import RandomForestRegressor
  5. from sklearn.tree import DecisionTreeRegressor
  6. from imodels import (
  7. RuleFitRegressor,
  8. SLIMRegressor,
  9. GreedyTreeRegressor,
  10. HSTreeRegressor,
  11. HSTreeRegressorCV,
  12. FIGSRegressor,
  13. DistilledRegressor,
  14. TaoTreeRegressor,
  15. BoostedRulesRegressor,
  16. TreeGAMRegressor,
  17. MarginalShrinkageLinearModelRegressor,
  18. )
  19. class TestClassRegression:
  20. def setup_method(self):
  21. np.random.seed(13)
  22. self.n = 10
  23. self.p = 10
  24. self.X_regression = np.random.randn(self.n, self.p)
  25. self.y_regression = self.X_regression[:,
  26. 0] + np.random.randn(self.n) * 0.01
  27. @pytest.mark.filterwarnings("ignore::UserWarning")
  28. def test_regression(self):
  29. """Test imodels on basic binary classification task"""
  30. for model_type in [
  31. RuleFitRegressor,
  32. SLIMRegressor,
  33. GreedyTreeRegressor,
  34. FIGSRegressor, # TaoTreeRegressor,
  35. BoostedRulesRegressor,
  36. partial(
  37. DistilledRegressor,
  38. teacher=RandomForestRegressor(n_estimators=3),
  39. student=DecisionTreeRegressor(),
  40. ),
  41. TreeGAMRegressor,
  42. MarginalShrinkageLinearModelRegressor,
  43. ]:
  44. if model_type == RuleFitRegressor:
  45. m = model_type(include_linear=False, max_rules=3)
  46. else:
  47. m = model_type()
  48. m.fit(self.X_regression, self.y_regression)
  49. preds = m.predict(self.X_regression)
  50. assert preds.size == self.n, "predictions are right size"
  51. mse = np.mean(np.square(preds - self.y_regression))
  52. assert mse < 1, "mse less than 1"
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...