Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

figs_test.py 4.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
  1. import os
  2. import random
  3. from functools import partial
  4. import numpy as np
  5. import pandas as pd
  6. from sklearn.tree import DecisionTreeRegressor
  7. from imodels import FIGSClassifier, FIGSRegressor, FIGSClassifierCV, FIGSRegressorCV
  8. from imodels.experimental.figs_ensembles import FIGSExtRegressor, FIGSExtClassifier
  9. from sklearn.ensemble import StackingRegressor, VotingRegressor, BaggingClassifier
  10. path_to_tests = os.path.dirname(os.path.realpath(__file__))
  11. class TestFIGS:
  12. def setup_method(self):
  13. '''Test on synthetic dataset
  14. '''
  15. np.random.seed(13)
  16. random.seed(13)
  17. self.n = 100
  18. self.p = 2
  19. self.X = (np.random.randn(self.n, self.p) > 0).astype(int)
  20. # y = x0 > 0 * x1 > 0
  21. self.y_classification_binary = (self.X[:, 0] > 0).astype(int) * (
  22. self.X[:, 1] > 0).astype(int)
  23. self.y_reg = self.X[:, 0] + self.X[:, 1]
  24. def test_recognized_by_sklearn(self):
  25. base_models = [('figs', FIGSRegressor()),
  26. ('random_forest', DecisionTreeRegressor())]
  27. comb_model = VotingRegressor(estimators=base_models,
  28. n_jobs=10,
  29. verbose=2)
  30. comb_model.fit(self.X, self.y_reg)
  31. # def test_categorical(self):
  32. # """Test FIGS with categorical data"""
  33. # categories = ['cat', 'dog', 'bird', 'fish']
  34. # categories_2 = ['bear', 'chicken', 'cow']
  35. # self.X_cat = pd.DataFrame(self.X)
  36. # self.X_cat['pet1'] = np.random.choice(categories, size=(self.n, 1))
  37. # self.X_cat['pet2'] = np.random.choice(categories_2, size=(self.n, 1))
  38. # figs_reg = FIGSRegressor()
  39. # figs_cls = FIGSClassifier()
  40. # figs_reg.fit(self.X_cat, self.y_reg,
  41. # categorical_features=["pet1", 'pet2'])
  42. # figs_reg.predict(self.X_cat, categorical_features=["pet1", 'pet2'])
  43. # figs_cls.fit(self.X_cat, self.y_reg,
  44. # categorical_features=["pet1", 'pet2'])
  45. # figs_cls.predict_proba(
  46. # self.X_cat, categorical_features=["pet1", 'pet2'])
  47. def test_fitting(self):
  48. '''Test on a real (small) dataset
  49. '''
  50. for model_type in [
  51. FIGSClassifier, FIGSRegressor,
  52. FIGSExtClassifier, FIGSExtRegressor,
  53. FIGSClassifierCV, FIGSRegressorCV,
  54. partial(BaggingClassifier,
  55. estimator=FIGSExtClassifier(max_rules=3),
  56. n_estimators=2),
  57. ]:
  58. init_kwargs = {}
  59. m = model_type(**init_kwargs)
  60. X = self.X
  61. m.fit(X, self.y_classification_binary)
  62. # test predict()
  63. preds = m.predict(X) # > 0.5).astype(int)
  64. assert preds.size == self.n, 'predict() yields right size'
  65. # test preds_proba()
  66. if model_type in [FIGSClassifier, FIGSClassifierCV, BaggingClassifier]:
  67. preds_proba = m.predict_proba(X)
  68. assert len(preds_proba.shape) == 2, 'preds_proba has 2 columns'
  69. assert preds_proba.shape[1] == 2, 'preds_proba has 2 columns'
  70. assert np.max(
  71. preds_proba) < 1.1, 'preds_proba has no values over 1'
  72. assert (np.argmax(preds_proba, axis=1) == preds).all(), ("predict_proba and "
  73. "predict correspond")
  74. # test acc
  75. acc_train = np.mean(preds == self.y_classification_binary)
  76. assert acc_train > 0.8, 'acc greater than 0.9'
  77. # print(m)
  78. if not type(m) in [FIGSClassifierCV, FIGSRegressorCV, BaggingClassifier]:
  79. trees = m.trees_
  80. assert len(trees) == 1, 'only one tree'
  81. assert trees[0].feature == 1, 'split on feat 1'
  82. assert np.abs(trees[0].left.value) < 0.01, 'left value 0'
  83. assert trees[0].left.left is None and trees[0].left.right is None, 'left is leaf'
  84. assert np.abs(
  85. trees[0].right.left.value) < 0.01, 'right-left value 0'
  86. assert np.abs(trees[0].right.right.value -
  87. 1) < 0.01, 'right-right value 1'
  88. if __name__ == '__main__':
  89. t = TestFIGS()
  90. t.setup_method()
  91. t.test_fitting()
  92. t.test_categorical()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...