Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

classification_binary_inputs_test.py 2.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
  1. import random
  2. import numpy as np
  3. from imodels import (
  4. OptimalRuleListClassifier,
  5. OptimalTreeClassifier,
  6. FPLassoClassifier,
  7. FPSkopeClassifier,
  8. TreeGAMClassifier,
  9. )
  10. class TestClassClassificationBinary:
  11. """Tests simple classification for different models. Note: still doesn't test all the models!"""
  12. def setup_method(self):
  13. np.random.seed(13)
  14. random.seed(13)
  15. self.n = 40
  16. self.p = 2
  17. self.X_classification_binary = (
  18. np.random.randn(self.n, self.p) > 0).astype(int)
  19. # y = x0 > 0
  20. self.y_classification_binary = (self.X_classification_binary[:, 0] > 0).astype(
  21. int
  22. )
  23. # flip labels for last few
  24. self.y_classification_binary[-2:] = 1 - \
  25. self.y_classification_binary[-2:]
  26. def test_classification_binary(self):
  27. """Test imodels on basic binary classification task"""
  28. for model_type in [
  29. OptimalRuleListClassifier,
  30. OptimalTreeClassifier,
  31. FPLassoClassifier,
  32. FPSkopeClassifier,
  33. TreeGAMClassifier,
  34. ]:
  35. init_kwargs = {}
  36. if model_type == FPSkopeClassifier:
  37. init_kwargs["recall_min"] = 0.5
  38. if model_type == TreeGAMClassifier:
  39. init_kwargs["n_boosting_rounds"] = 10
  40. m = model_type(**init_kwargs)
  41. X = self.X_classification_binary
  42. m.fit(X, self.y_classification_binary)
  43. # test predict()
  44. preds = m.predict(X) # > 0.5).astype(int)
  45. assert preds.size == self.n, "predict() yields right size"
  46. # test preds_proba()
  47. if model_type not in {OptimalRuleListClassifier, OptimalTreeClassifier}:
  48. preds_proba = m.predict_proba(X)
  49. assert len(preds_proba.shape) == 2, "preds_proba has 2 columns"
  50. assert preds_proba.shape[1] == 2, "preds_proba has 2 columns"
  51. assert np.max(
  52. preds_proba) < 1.1, "preds_proba has no values over 1"
  53. assert (np.argmax(preds_proba, axis=1) == preds).all(), (
  54. "predict_proba and " "predict correspond"
  55. )
  56. # test acc
  57. acc_train = np.mean(preds == self.y_classification_binary)
  58. # print(type(m), m, 'final acc', acc_train)
  59. assert acc_train > 0.8, "acc greater than 0.8"
  60. if __name__ == "__main__":
  61. t = TestClassClassificationBinary()
  62. t.setup()
  63. t.test_classification_binary()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...