Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

shrinkage_test.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
  1. import random
  2. from functools import partial
  3. import numpy as np
  4. from sklearn.ensemble import VotingRegressor, RandomForestClassifier, GradientBoostingClassifier
  5. from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
  6. from imodels import HSTreeClassifier, HSTreeClassifierCV, \
  7. HSTreeRegressor, HSTreeRegressorCV, C45TreeClassifier
  8. # OptimalTreeClassifier, HSOptimalTreeClassifierCV
  9. from imodels.tree.c45_tree.c45_tree import HSC45TreeClassifierCV
  10. import random
  11. from functools import partial
  12. import numpy as np
  13. from sklearn.ensemble import VotingRegressor
  14. from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
  15. from imodels import HSTreeClassifier, HSTreeClassifierCV, \
  16. HSTreeRegressor, HSTreeRegressorCV, C45TreeClassifier
  17. # OptimalTreeClassifier, HSOptimalTreeClassifierCV
  18. from imodels.tree.c45_tree.c45_tree import HSC45TreeClassifierCV
  19. class TestShrinkage:
  20. '''Tests simple classification for different models. Note: still doesn't test all the models!
  21. '''
  22. def setup_method(self):
  23. np.random.seed(13)
  24. random.seed(13)
  25. self.n = 20
  26. self.p = 2
  27. self.X_classification_binary = (
  28. np.random.randn(self.n, self.p) > 0).astype(int)
  29. # y = x0 > 0
  30. self.y_classification_binary = (
  31. self.X_classification_binary[:, 0] > 0).astype(int)
  32. # flip labels for last few
  33. self.y_classification_binary[-2:] = 1 - \
  34. self.y_classification_binary[-2:]
  35. self.X_regression = np.random.randn(self.n, self.p)
  36. self.y_regression = self.X_regression[:,
  37. 0] + np.random.randn(self.n) * 0.01
  38. def test_classification_shrinkage(self):
  39. '''Test imodels on basic binary classification task
  40. '''
  41. for model_type in [
  42. partial(HSTreeClassifier, estimator_=DecisionTreeClassifier()),
  43. partial(HSTreeClassifier, estimator_=GradientBoostingClassifier()),
  44. partial(HSTreeClassifier, estimator_=DecisionTreeClassifier()),
  45. partial(HSTreeClassifierCV, estimator_=DecisionTreeClassifier()),
  46. partial(HSTreeClassifierCV, estimator_=RandomForestClassifier()),
  47. partial(HSC45TreeClassifierCV, estimator_=C45TreeClassifier()),
  48. HSTreeClassifierCV, # default estimator is Decision tree with 25 max_leaf_nodes
  49. # partial(HSOptimalTreeClassifierCV, estimator_=OptimalTreeClassifier()),
  50. ]:
  51. init_kwargs = {}
  52. m = model_type(**init_kwargs)
  53. X = self.X_classification_binary
  54. m.fit(X, self.y_classification_binary)
  55. # test predict()
  56. preds = m.predict(X) # > 0.5).astype(int)
  57. assert preds.size == self.n, 'predict() yields right size'
  58. # test preds_proba()
  59. preds_proba = m.predict_proba(X)
  60. assert len(preds_proba.shape) == 2, 'preds_proba has 2 columns'
  61. assert preds_proba.shape[1] == 2, 'preds_proba has 2 columns'
  62. assert np.max(
  63. preds_proba) < 1.1, 'preds_proba has no values over 1'
  64. assert (np.argmax(preds_proba, axis=1) == preds).all(
  65. ), ("predict_proba and ""predict correspond")
  66. # test acc
  67. acc_train = np.mean(preds == self.y_classification_binary)
  68. # print(type(m), m, 'final acc', acc_train)
  69. assert acc_train > 0.8, 'acc greater than 0.8'
  70. # complexity
  71. assert m.complexity_ > 0, 'complexity is greater than 0'
  72. def test_recognized_by_sklearn(self):
  73. base_models = [('hs', HSTreeRegressor(DecisionTreeRegressor())),
  74. ('dt', DecisionTreeRegressor())]
  75. comb_model = VotingRegressor(estimators=base_models,
  76. n_jobs=10,
  77. verbose=2)
  78. comb_model.fit(self.X_classification_binary, self.y_regression)
  79. def test_regression_shrinkage(self):
  80. '''Test imodels on basic binary classification task
  81. '''
  82. for model_type in [partial(HSTreeRegressor, estimator_=DecisionTreeRegressor()),
  83. partial(HSTreeRegressorCV,
  84. estimator_=DecisionTreeRegressor()),
  85. ]:
  86. m = model_type()
  87. m.fit(self.X_regression, self.y_regression)
  88. preds = m.predict(self.X_regression)
  89. assert preds.size == self.n, 'predictions are right size'
  90. mse = np.mean(np.square(preds - self.y_regression))
  91. assert mse < 1, 'mse less than 1'
  92. # complexity
  93. assert m.complexity_ > 0, 'complexity is greater than 0'
  94. if __name__ == '__main__':
  95. t = TestShrinkage()
  96. t.setup_method()
  97. t.test_classification_shrinkage()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...