Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

cart_wrapper.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
  1. # This is just a simple wrapper around sklearn decisiontree
  2. # https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
  3. from sklearn.tree import DecisionTreeClassifier, export_text, DecisionTreeRegressor
  4. from imodels.util.arguments import check_fit_arguments
  5. from imodels.util.tree import compute_tree_complexity
  6. class GreedyTreeClassifier(DecisionTreeClassifier):
  7. """Wrapper around sklearn greedy tree classifier
  8. """
  9. def fit(self, X, y, feature_names=None, sample_weight=None, check_input=True):
  10. """Build a decision tree classifier from the training set (X, y).
  11. Parameters
  12. ----------
  13. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  14. The training input samples. Internally, it will be converted to
  15. ``dtype=np.float32`` and if a sparse matrix is provided
  16. to a sparse ``csc_matrix``.
  17. y : array-like of shape (n_samples,) or (n_samples, n_outputs)
  18. The target values (class labels) as integers or strings.
  19. feature_names : array-like of shape (n_features)
  20. The names of the features
  21. sample_weight : array-like of shape (n_samples,), default=None
  22. Sample weights. If None, then samples are equally weighted. Splits
  23. that would create child nodes with net zero or negative weight are
  24. ignored while searching for a split in each node. Splits are also
  25. ignored if they would result in any single class carrying a
  26. negative weight in either child node.
  27. check_input : bool, default=True
  28. Allow to bypass several input checking.
  29. Don't use this parameter unless you know what you do.
  30. Returns
  31. -------
  32. self : DecisionTreeClassifier
  33. Fitted estimator.
  34. """
  35. X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
  36. super().fit(X, y, sample_weight=sample_weight, check_input=check_input)
  37. self._set_complexity()
  38. def _set_complexity(self):
  39. """Set complexity as number of non-leaf nodes
  40. """
  41. self.complexity_ = compute_tree_complexity(self.tree_)
  42. def __str__(self):
  43. s = '> ------------------------------\n'
  44. s += '> Greedy CART Tree:\n'
  45. s += '> \tPrediction is made by looking at the value in the appropriate leaf of the tree\n'
  46. s += '> ------------------------------' + '\n'
  47. if hasattr(self, 'feature_names') and self.feature_names is not None:
  48. return s + export_text(self, feature_names=self.feature_names, show_weights=True)
  49. else:
  50. return s + export_text(self, show_weights=True)
  51. class GreedyTreeRegressor(DecisionTreeRegressor):
  52. """Wrapper around sklearn greedy tree regressor
  53. """
  54. def fit(self, X, y, feature_names=None, sample_weight=None, check_input=True):
  55. """Build a decision tree regressor from the training set (X, y).
  56. Parameters
  57. ----------
  58. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  59. The training input samples. Internally, it will be converted to
  60. ``dtype=np.float32`` and if a sparse matrix is provided
  61. to a sparse ``csc_matrix``.
  62. y : array-like of shape (n_samples,) or (n_samples, n_outputs)
  63. The target values (real numbers). Use ``dtype=np.float64`` and
  64. ``order='C'`` for maximum efficiency.
  65. sample_weight : array-like of shape (n_samples,), default=None
  66. Sample weights. If None, then samples are equally weighted. Splits
  67. that would create child nodes with net zero or negative weight are
  68. ignored while searching for a split in each node.
  69. check_input : bool, default=True
  70. Allow to bypass several input checking.
  71. Don't use this parameter unless you know what you do.
  72. Returns
  73. -------
  74. self : DecisionTreeRegressor
  75. Fitted estimator.
  76. """
  77. if feature_names is not None:
  78. self.feature_names = feature_names
  79. else:
  80. self.feature_names = ["X" + str(i + 1) for i in range(X.shape[1])]
  81. super().fit(X, y, sample_weight=sample_weight, check_input=check_input)
  82. self._set_complexity()
  83. def _set_complexity(self):
  84. """Set complexity as number of non-leaf nodes
  85. """
  86. self.complexity_ = compute_tree_complexity(self.tree_)
  87. def __str__(self):
  88. if hasattr(self, 'feature_names') and self.feature_names is not None:
  89. return 'GreedyTree:\n' + export_text(self, feature_names=self.feature_names, show_weights=True)
  90. else:
  91. return 'GreedyTree:\n' + export_text(self, show_weights=True)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...