#603 Support 1.13

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-489-torch_1.13
from typing import Type, Union, Mapping, Any

import torch
from torch import nn

from super_gradients.modules import RepVGGBlock
from super_gradients.modules.skip_connections import Residual


class QARepVGGBlock(nn.Module):
  7. """
  8. QARepVGG (S3/S4) block from 'Make RepVGG Greater Again: A Quantization-aware Approach' (https://arxiv.org/pdf/2212.01593.pdf)
  9. It consists of three branches:
  10. 3x3: a branch of a 3x3 Convolution + BatchNorm
  11. 1x1: a branch of a 1x1 Convolution with bias
  12. identity: a Residual branch which will only be used if input channel == output channel and use_residual_connection is True
  13. (usually in all but the first block of each stage)
  14. BatchNorm is applied after summation of all three branches.
  15. In contrast to our implementation of RepVGGBlock, SE is applied AFTER NONLINEARITY in order to fuse Conv+Act in inference frameworks.
  16. This module converts to Conv+Act in a PTQ-friendly way by calling QARepVGGBlock.fuse_block_residual_branches().
  17. Has the same API as RepVGGBlock and is designed to be a plug-and-play replacement but is not compatible parameter-wise.
  18. Has less trainable parameters than RepVGGBlock because it has only 2 BatchNorms instead of 3.
  19. |
  20. |
  21. |---------------|---------------|
  22. | | |
  23. 3x3 1x1 |
  24. | | |
  25. BatchNorm +bias |
  26. | | |
  27. | *alpha |
  28. | | |
  29. |---------------+---------------|
  30. |
  31. BatchNorm
  32. |
  33. Act
  34. |
  35. SE
  36. """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        activation_type: Type[nn.Module] = nn.ReLU,
        activation_kwargs: Union[Mapping[str, Any], None] = None,
        se_type: Type[nn.Module] = nn.Identity,
        se_kwargs: Union[Mapping[str, Any], None] = None,
        build_residual_branches: bool = True,
        use_residual_connection: bool = True,
        use_alpha: bool = False,
        use_1x1_bias: bool = True,
        use_post_bn: bool = True,
    ):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param stride: Output stride
        :param dilation: Dilation factor for the 3x3 conv
        :param groups: Number of groups used in the convolutions
        :param activation_type: Type of the nonlinearity
        :param activation_kwargs: Additional arguments for instantiating the activation module
        :param se_type: Type of the SE block (use nn.Identity to disable SE)
        :param se_kwargs: Additional arguments for instantiating the SE module
        :param build_residual_branches: Whether to build the unfused residual branches; if False, the block is initialized with already fused parameters (for deployment)
        :param use_residual_connection: Whether to add the input x to the output (enabled in RepVGG, disabled in PP-YOLO)
        :param use_alpha: If True, enables an additional learnable weighting parameter for the 1x1 branch (PP-YOLO-E Plus)
        :param use_1x1_bias: If True, enables bias in the 1x1 convolution; the authors don't mention it specifically
        :param use_post_bn: If True, adds BatchNorm after the sum of the three branches (S4); if False, it is not added (S3)
        """
        super().__init__()

        if activation_kwargs is None:
            activation_kwargs = {}
        if se_kwargs is None:
            se_kwargs = {}

        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.dilation = dilation
        self.activation_type = activation_type
        self.activation_kwargs = activation_kwargs
        self.se_type = se_type
        self.se_kwargs = se_kwargs
        self.use_residual_connection = use_residual_connection
        self.use_alpha = use_alpha
        self.use_1x1_bias = use_1x1_bias
        self.use_post_bn = use_post_bn

        self.nonlinearity = activation_type(**activation_kwargs)
        self.se = se_type(**se_kwargs)

        self.branch_3x3 = nn.Sequential()
        self.branch_3x3.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                groups=groups,
                bias=False,
                dilation=dilation,
            ),
        )
        self.branch_3x3.add_module("bn", nn.BatchNorm2d(num_features=out_channels))

        self.branch_1x1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
            groups=groups,
            bias=use_1x1_bias,
        )

        if use_residual_connection:
            assert out_channels == in_channels and stride == 1
            self.identity = Residual()

            input_dim = self.in_channels // self.groups
            self.id_tensor = torch.zeros((self.in_channels, input_dim, 3, 3))
            for i in range(self.in_channels):
                self.id_tensor[i, i % input_dim, 1, 1] = 1.0
            self.id_tensor = self.id_tensor.to(dtype=self.branch_1x1.weight.dtype, device=self.branch_1x1.weight.device)
        else:
            self.identity = None

        if use_alpha:
            self.alpha = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        else:
            self.alpha = 1.0

        if self.use_post_bn:
            self.post_bn = nn.BatchNorm2d(num_features=out_channels)
        else:
            self.post_bn = nn.Identity()

        # placeholder to correctly register parameters
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.branch_3x3.conv.in_channels,
            out_channels=self.branch_3x3.conv.out_channels,
            kernel_size=self.branch_3x3.conv.kernel_size,
            stride=self.branch_3x3.conv.stride,
            padding=self.branch_3x3.conv.padding,
            dilation=self.branch_3x3.conv.dilation,
            groups=self.branch_3x3.conv.groups,
            bias=True,
        )

        self.partially_fused = False
        self.fully_fused = False

        if not build_residual_branches:
            self.fuse_block_residual_branches()
    def forward(self, inputs):
        if self.fully_fused:
            return self.se(self.nonlinearity(self.rbr_reparam(inputs)))

        if self.partially_fused:
            return self.se(self.nonlinearity(self.post_bn(self.rbr_reparam(inputs))))

        if self.identity is None:
            id_out = 0.0
        else:
            id_out = self.identity(inputs)

        x_3x3 = self.branch_3x3(inputs)
        x_1x1 = self.alpha * self.branch_1x1(inputs)

        branches = x_3x3 + x_1x1 + id_out
        out = self.nonlinearity(self.post_bn(branches))
        return self.se(out)
    def _get_equivalent_kernel_bias_for_branches(self):
        """
        Fuses the 3x3, 1x1 and identity branches into a single 3x3 conv layer
        """
        kernel3x3, bias3x3 = self._fuse_bn_tensor(
            self.branch_3x3.conv.weight,
            0,
            self.branch_3x3.bn.running_mean,
            self.branch_3x3.bn.running_var,
            self.branch_3x3.bn.weight,
            self.branch_3x3.bn.bias,
            self.branch_3x3.bn.eps,
        )
        kernel1x1 = self._pad_1x1_to_3x3_tensor(self.branch_1x1.weight)
        bias1x1 = self.branch_1x1.bias if self.branch_1x1.bias is not None else 0
        kernelid = self.id_tensor if self.identity is not None else 0
        biasid = 0

        eq_kernel_3x3 = kernel3x3 + self.alpha * kernel1x1 + kernelid
        eq_bias_3x3 = bias3x3 + self.alpha * bias1x1 + biasid

        return eq_kernel_3x3, eq_bias_3x3

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """
        Pads the 1x1 convolution weights with zeros so that they can be fused with the 3x3 conv layer

        :param kernel1x1: weights of the 1x1 convolution
        :return: padded 1x1 weights
        """
        if kernel1x1 is None:
            return 0
        return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, kernel, bias, running_mean, running_var, gamma, beta, eps):
        # standard BatchNorm folding: scale the kernel per output channel and shift the bias
        std = torch.sqrt(running_var + eps)
        b = beta - gamma * running_mean / std
        A = gamma / std
        A_ = A.expand_as(kernel.transpose(0, -1)).transpose(0, -1)
        return kernel * A_, bias * A + b
    def full_fusion(self):
        """
        Fuse everything into Conv-Act-SE: non-trainable, parameters detached.
        Converts a QARepVGG block from training mode (with branches) to deployment mode (a VGG-like model).
        """
        if self.fully_fused:
            return

        if not self.partially_fused:
            self.partial_fusion()

        if self.use_post_bn:
            eq_kernel, eq_bias = self._fuse_bn_tensor(
                self.rbr_reparam.weight,
                self.rbr_reparam.bias,
                self.post_bn.running_mean,
                self.post_bn.running_var,
                self.post_bn.weight,
                self.post_bn.bias,
                self.post_bn.eps,
            )
            self.rbr_reparam.weight.data = eq_kernel
            self.rbr_reparam.bias.data = eq_bias

        for para in self.parameters():
            para.detach_()

        if hasattr(self, "post_bn"):
            self.__delattr__("post_bn")

        self.partially_fused = False
        self.fully_fused = True

    def partial_fusion(self):
        """Fuse the branches into a single kernel, leave post_bn unfused, leave parameters differentiable"""
        if self.partially_fused:
            return

        if self.fully_fused:
            # TODO: we actually can, all we need to do is insert the properly initialized post_bn back
            # init is not trivial, so not implemented for now
            raise NotImplementedError("QARepVGGBlock can't be converted to partially fused from fully fused")

        kernel, bias = self._get_equivalent_kernel_bias_for_branches()
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias

        self.__delattr__("branch_3x3")
        self.__delattr__("branch_1x1")
        if hasattr(self, "identity"):
            self.__delattr__("identity")
        if hasattr(self, "alpha"):
            self.__delattr__("alpha")
        if hasattr(self, "id_tensor"):
            self.__delattr__("id_tensor")

        self.partially_fused = True
        self.fully_fused = False

    def fuse_block_residual_branches(self):
        self.full_fusion()

    def from_repvgg(self, repvgg_block: RepVGGBlock):
        raise NotImplementedError
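For anyone trying the block out: the intended lifecycle is to train with the multi-branch structure, then collapse it for deployment via fuse_block_residual_branches(). A minimal sketch, assuming super-gradients is installed from this branch and that the class is importable from super_gradients.modules.qarepvgg_block (the import path is an assumption based on this file):

import torch
from super_gradients.modules.qarepvgg_block import QARepVGGBlock  # import path is an assumption

# Training-mode block: 3x3+BN, 1x1(+bias) and identity branches, followed by post-BN.
block = QARepVGGBlock(in_channels=64, out_channels=64, stride=1)
block.eval()  # fusion folds BN running statistics, so switch out of training mode first

x = torch.randn(1, 64, 32, 32)
with torch.no_grad():
    y_branches = block(x)
    block.fuse_block_residual_branches()  # full fusion: a single Conv -> Act -> SE
    y_fused = block(x)

# Up to floating-point error, the fused block computes the same function.
assert torch.allclose(y_branches, y_fused, atol=1e-5)

Note that partial_fusion() keeps post_bn and leaves parameters differentiable, while full_fusion() detaches everything; as the NotImplementedError above indicates, a fully fused block cannot currently be converted back to partially fused.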
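The arithmetic behind _fuse_bn_tensor is the standard BatchNorm-folding identity: a conv with kernel W and bias b followed by an eval-mode BN with running statistics mu, var and affine parameters gamma, beta equals a single conv with W' = W * gamma / sqrt(var + eps) (scaled per output channel) and b' = (b - mu) * gamma / sqrt(var + eps) + beta. A self-contained sanity check of that identity, using plain torch (all names are local to the example):

import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8).eval()
# give BN non-trivial statistics and affine parameters so the check is meaningful
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1, 1)

with torch.no_grad():
    # fold BN into the conv, mirroring _fuse_bn_tensor above
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std  # "A" in _fuse_bn_tensor
    fused = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))  # per-output-channel scaling
    fused.bias.copy_(conv.bias * scale + bn.bias - bn.weight * bn.running_mean / std)

    x = torch.randn(2, 8, 16, 16)
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)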
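Similarly, the id_tensor built in __init__ encodes the identity branch as a 3x3 kernel whose only nonzero entry is a 1.0 at the center of each channel's own input slice, which is what allows partial_fusion() to simply add it to the fused 3x3 kernel. A quick standalone illustration for the groups=1 case (again, names are local to the example):

import torch
from torch import nn

channels = 4
id_kernel = torch.zeros(channels, channels, 3, 3)  # groups=1 case of id_tensor
for i in range(channels):
    id_kernel[i, i, 1, 1] = 1.0  # 1.0 at the center of the channel's own slice

x = torch.randn(1, channels, 8, 8)
y = nn.functional.conv2d(x, id_kernel, padding=1)
assert torch.allclose(x, y)  # the conv reproduces its input exactly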