Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

classification_continuous_inputs_test.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
  1. import random
  2. import numpy as np
  3. from imodels import * # noqa: F403
  4. class TestClassClassificationContinuousInputs:
  5. '''Tests simple classification for different models. Note: still doesn't test all the models!
  6. '''
  7. def setup_method(self):
  8. np.random.seed(13)
  9. random.seed(13)
  10. self.n = 40
  11. self.p = 2
  12. self.X_classification_binary = np.random.randn(self.n, self.p)
  13. # y = x0 > 0
  14. self.y_classification_binary = (
  15. self.X_classification_binary[:, 0] > 0).astype(int)
  16. # flip labels for last few
  17. self.y_classification_binary[-2:] = 1 - \
  18. self.y_classification_binary[-2:]
  19. def test_classification_binary(self):
  20. '''Test imodels on basic binary classification task
  21. '''
  22. for model_type in [
  23. BoostedRulesClassifier,
  24. TaoTreeClassifier,
  25. RuleFitClassifier, GreedyRuleListClassifier,
  26. SkopeRulesClassifier,
  27. OneRClassifier, SlipperClassifier,
  28. GreedyTreeClassifier, OptimalTreeClassifier,
  29. C45TreeClassifier, FIGSClassifier,
  30. TreeGAMClassifier,
  31. ]: # IRFClassifier, SLIMClassifier, BayesianRuleSetClassifier,
  32. init_kwargs = {}
  33. if model_type == SkopeRulesClassifier or model_type == FPSkopeClassifier:
  34. init_kwargs['random_state'] = 0
  35. init_kwargs['max_samples_features'] = 1.
  36. elif model_type == SlipperClassifier:
  37. init_kwargs['n_estimators'] = 1
  38. elif model_type == TreeGAMClassifier:
  39. init_kwargs['n_boosting_rounds'] = 10
  40. m = model_type(**init_kwargs)
  41. X = self.X_classification_binary
  42. m.fit(X, self.y_classification_binary)
  43. # test predict()
  44. preds = m.predict(X) # > 0.5).astype(int)
  45. assert preds.size == self.n, 'predict() yields right size'
  46. # test preds_proba()
  47. if model_type not in {OptimalRuleListClassifier, OptimalTreeClassifier}:
  48. preds_proba = m.predict_proba(X)
  49. assert len(preds_proba.shape) == 2, 'preds_proba has 2 columns'
  50. assert preds_proba.shape[1] == 2, 'preds_proba has 2 columns'
  51. assert np.max(
  52. preds_proba) < 1.1, 'preds_proba has no values over 1'
  53. assert (np.argmax(preds_proba, axis=1) == preds).all(), ("predict_proba and "
  54. "predict agree")
  55. # test acc
  56. acc_train = np.mean(preds == self.y_classification_binary)
  57. # print(type(m), m, 'final acc', acc_train)
  58. assert acc_train > 0.8, 'acc greater than 0.8'
  59. if __name__ == '__main__':
  60. t = TestClassClassificationContinuousInputs()
  61. t.setup_method()
  62. t.test_classification_binary()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...