Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

neural_nets.py 6.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
  """Bridging random forests and deep neural networks.

  Code to convert a fitted sklearn decision tree into an equivalent PyTorch
  neural network, following "Neural Random Forests"
  https://arxiv.org/abs/1604.07143

  Only the first output column of the tree's leaf values is used, so the
  conversion targets single-output regression trees.

  Example
  -------
  import numpy as np
  import torch
  from sklearn.tree import DecisionTreeRegressor
  np.random.seed(13)
  num_features = 4
  N = 1000
  max_depth = 100
  # prepare data
  X = np.random.rand(N, num_features)
  y = np.random.rand(N)
  X_t = torch.Tensor(X)
  # train a decision tree
  dt = DecisionTreeRegressor(max_depth=max_depth)
  dt.fit(X, y)
  # prepare net (Net is the class defined in this module)
  net = Net(dt)
  # check if preds are close
  preds_dt = dt.predict(X).flatten()
  preds_net = net(X_t).detach().numpy().flatten()
  assert np.isclose(preds_dt, preds_net).all(), 'preds are not close'
  """
  25. import time
  26. from copy import deepcopy
  27. import numpy as np
  28. from torch import nn
  class Net(nn.Module):
      '''
      Convert a fitted sklearn decision tree into an equivalent three-layer network.

      Layer 0 has one neuron per internal (split) node and computes
      ``x[feature] - threshold``. Its sign is turned into a routing decision:
      -1 for "go left" (``x[feature] <= threshold``, as in sklearn) and +1 for
      "go right".
      Layer 1 has one neuron per leaf. The weights are the +/-1 signs along the
      root-to-leaf path and the bias is minus the leaf depth, so the neuron is
      exactly 0 only when every decision on the path matches that leaf.
      Layer 2 (no bias) maps the resulting one-hot leaf indicator to the leaf value.

      Only ``tree_.value[leaf][0, 0]`` is used per leaf, so the output is a single
      column. This suits single-output regression trees, not classifiers.

      Attributes
      ----------
      values : the estimator's ``tree_.value`` array
      all_leaf_paths : dict mapping leaf node id -> list of (split node id, sign),
          where sign is -1 for a left branch and +1 for a right branch
      layers : nn.Sequential holding the three nn.Linear layers described above
      '''
      def __init__(self, estimator):
          '''
          Parameters
          ----------
          estimator : fitted sklearn tree estimator exposing ``tree_`` and
              ``n_features_in_`` (or ``n_features_`` on older sklearn versions)
          '''
          super().__init__()
          tree = estimator.tree_
          n_nodes = tree.node_count
          children_left = tree.children_left  # id of the left child of each node
          children_right = tree.children_right  # id of the right child of each node
          feature = tree.feature  # feature used for splitting each node
          threshold = tree.threshold  # threshold value at each node
          num_leaves = tree.n_leaves
          num_non_leaves = n_nodes - num_leaves
          node_depth, is_leaves = self.calc_depths_and_leaves(n_nodes, children_left, children_right)
          self.values = tree.value
          self.all_leaf_paths = {}
          self.calc_all_leaf_paths(0, n_nodes, children_left, children_right, running_list=[])  # fills all_leaf_paths

          self.layers = nn.Sequential(
              nn.Linear(self._num_features(estimator), num_non_leaves),
              nn.Linear(num_non_leaves, num_leaves),
              nn.Linear(num_leaves, 1, bias=False),
          )
          # Start layers 0 and 1 from zero so only the entries set below are non-zero.
          # Layer 2's weight needs no reset: every entry (one per leaf) is overwritten below.
          for layer in self.layers[:2]:
              layer.weight.data.zero_()
              layer.bias.data.zero_()

          # Layer 0: one neuron per split node (in node-id order) computing x[feature] - threshold.
          split_nodes = [node for node in range(n_nodes) if not is_leaves[node]]
          nonleaf_node_to_nonleaf_neuron_num = {}
          for neuron, node in enumerate(split_nodes):
              self.layers[0].weight.data[neuron, feature[node]] = 1
              self.layers[0].bias.data[neuron] = -threshold[node]
              nonleaf_node_to_nonleaf_neuron_num[node] = neuron

          # Layers 1 and 2: one neuron / weight per leaf, in sorted leaf-id order.
          for leaf_neuron_num, leaf_idx in enumerate(sorted(self.all_leaf_paths)):
              for nonleaf_node, sign in self.all_leaf_paths[leaf_idx]:
                  # layer 1 weight is num_leaves x num_non_leaves
                  self.layers[1].weight.data[leaf_neuron_num,
                                             nonleaf_node_to_nonleaf_neuron_num[nonleaf_node]] = sign
              # The path has node_depth[leaf] decisions, so the sum equals the depth only
              # when all of them agree; subtracting the depth makes that case exactly 0.
              self.layers[1].bias.data[leaf_neuron_num] = -float(node_depth[leaf_idx])
              # Only the first output / class column is used (see class docstring).
              self.layers[2].weight.data[0, leaf_neuron_num] = self.values[leaf_idx][0, 0]

      @staticmethod
      def _num_features(estimator):
          '''Return the number of input features: ``n_features_in_``, or ``n_features_`` if absent.'''
          num_features = getattr(estimator, 'n_features_in_', None)
          return estimator.n_features_ if num_features is None else num_features

      def forward(self, x):
          '''
          Evaluate the tree on a batch.

          ``x`` is reshaped to (batch, -1). The result has shape (batch, 1) and holds,
          for each sample, the value of the leaf it is routed to.
          '''
          x = x.reshape(x.shape[0], -1)
          h = self.layers[0](x)
          # sklearn sends x[feature] <= threshold to the LEFT child, so a difference
          # of exactly 0 must map to -1 (left); only strictly positive goes right (+1).
          decisions = (h > 0).to(h.dtype) * 2 - 1
          h = self.layers[1](decisions)
          # Exactly one leaf neuron is 0 (all of its path decisions agree) -> one-hot mask.
          leaf_mask = (h == 0).to(h.dtype)
          return self.layers[2](leaf_mask)

      def calc_depths_and_leaves(self, n_nodes, children_left, children_right):
          '''
          Compute the depth of every node and whether it is a leaf.

          Walks the tree from the root (node 0) with an explicit stack. A node whose
          left and right child ids are equal is treated as a leaf.

          Returns
          -------
          node_depth : int64 array of length n_nodes (the root has depth 0)
          is_leaves : bool array of length n_nodes
          '''
          node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
          is_leaves = np.zeros(shape=n_nodes, dtype=bool)
          stack = [(0, -1)]  # seed: root node id and its (virtual) parent depth
          while stack:
              node_id, parent_depth = stack.pop()
              node_depth[node_id] = parent_depth + 1
              if children_left[node_id] != children_right[node_id]:  # split node
                  stack.append((children_left[node_id], parent_depth + 1))
                  stack.append((children_right[node_id], parent_depth + 1))
              else:
                  is_leaves[node_id] = True
          return node_depth, is_leaves

      def calc_all_leaf_paths(self, node_idx, n_nodes, children_left, children_right, running_list):
          '''
          Store the path to every leaf below ``node_idx`` in ``self.all_leaf_paths``.

          Each path is a list of (split node id, sign) tuples. sign is -1 when the path
          takes the left child and +1 when it takes the right child. ``running_list``
          is the path prefix that leads to ``node_idx``; it is copied, never mutated.
          ``n_nodes`` is unused and kept for interface compatibility.

          Uses an explicit stack instead of recursion, so very deep trees cannot hit
          Python's recursion limit. The right child is pushed first so the left subtree
          is visited first, as in a recursive depth-first traversal.
          '''
          stack = [(node_idx, list(running_list))]
          while stack:
              node, path = stack.pop()
              left, right = children_left[node], children_right[node]
              if left == right:  # leaf
                  self.all_leaf_paths[node] = path
              else:
                  stack.append((right, path + [(node, +1)]))
                  stack.append((left, path + [(node, -1)]))

      def extract_util_np(self):
          '''
          Export the network's parameters as numpy arrays.

          Returns
          -------
          b0 : layer-0 biases (minus the threshold of each split neuron)
          idxs0 : feature index used by each split neuron (argmax of its one-hot weight row)
          w1 : layer-1 weight transposed, shape (num_non_leaves, num_leaves)
          b1 : layer-1 biases (minus the depth of each leaf)
          idxs2 : value of each leaf, in sorted leaf-id order
          '''
          b0 = self.layers[0].bias.data.cpu().numpy()
          idxs0 = self.layers[0].weight.data.argmax(dim=1).cpu().numpy()
          w1 = self.layers[1].weight.data.cpu().numpy().T
          b1 = self.layers[1].bias.data.cpu().numpy()
          num_leaves = self.layers[2].weight.shape[1]
          idxs2 = np.zeros(num_leaves)  # leaf_neuron_num -> leaf value
          for leaf_neuron_num, leaf_idx in enumerate(sorted(self.all_leaf_paths)):
              idxs2[leaf_neuron_num] = self.values[leaf_idx, 0, 0]
          return b0, idxs0, w1, b1, idxs2
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...