Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

grl_test.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
  1. import unittest
  2. import numpy as np
  3. from imodels.rule_list.greedy_rule_list import GreedyRuleListClassifier
  4. import sklearn
  5. from sklearn.model_selection import train_test_split
  6. class TestGRL(unittest.TestCase):
  7. @classmethod
  8. def setUpClass(cls):
  9. cls.m = GreedyRuleListClassifier()
  10. def test_integration_stability(self):
  11. '''Test on synthetic dataset
  12. '''
  13. X = np.array(
  14. [[0, 0, 1, 1, 0],
  15. [1, 0, 0, 0, 0],
  16. [0, 0, 1, 0, 0],
  17. [1, 0, 0, 0, 0],
  18. [1, 1, 0, 1, 1],
  19. [1, 1, 1, 1, 1],
  20. [0, 1, 1, 1, 1],
  21. [1, 0, 1, 1, 1]])
  22. y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
  23. self.m.fit(X, y)
  24. yhat = self.m.predict(X)
  25. acc = np.mean(y == yhat) * 100
  26. assert acc > 99 # acc must be 100
  27. def test_linear_separability(self):
  28. """Test if the model can learn a linearly separable dataset"""
  29. x = np.array([0.8, 0.8, 0.3, 0.3, 0.3, 0.3]).reshape(-1, 1)
  30. y = np.array([0, 0, 1, 1, 1, 1])
  31. self.m.fit(x, y, verbose=True)
  32. yhat = self.m.predict(x)
  33. acc = np.mean(y == yhat) * 100
  34. assert len(self.m.rules_) == 2
  35. assert acc == 100 # acc must be 100
  36. def test_y_left_conditional_probability(self):
  37. """Test conditional probability of y given x in the left node"""
  38. x = np.array([0.8, 0.8, 0.3, 0.3, 0.3, 0.3]).reshape(-1, 1)
  39. y = np.array([0, 0, 1, 1, 1, 1])
  40. self.m.fit(x, y, verbose=True)
  41. assert self.m.rules_[1]["val"] == 0
  42. def test_breast_cancer():
  43. np.random.seed(13)
  44. X, Y = sklearn.datasets.load_breast_cancer(as_frame=True, return_X_y=True)
  45. model = GreedyRuleListClassifier(max_depth=10)
  46. X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size = 0.3)
  47. model.fit(X_train, y_train, feature_names=X_train.columns)
  48. y_pred = model.predict(X_test)
  49. # score = accuracy_score(y_test.values,y_pred)
  50. # print('Accuracy:', score)
  51. # model._print_list()
  52. if __name__ == '__main__':
  53. test_breast_cancer()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...