Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
  1. import torch
  2. import torch.nn as nn
  3. from typing import Union, List
  4. from super_gradients.modules import ConvBNReLU
  5. from super_gradients.training.utils.module_utils import make_upsample_module
  6. from super_gradients.common import UpsampleMode
  7. from super_gradients.training.models.segmentation_models.stdc import AbstractSTDCBackbone, STDC1Backbone, STDC2Backbone
  8. from super_gradients.training.models.segmentation_models.common import SegmentationHead
  9. from super_gradients.training.models.segmentation_models.segmentation_module import SegmentationModule
  10. from super_gradients.training.utils import HpmStruct, get_param, torch_version_is_greater_or_equal
  11. from super_gradients.training.models.segmentation_models.context_modules import SPPM
  12. class UAFM(nn.Module):
  13. """
  14. Unified Attention Fusion Module, which uses mean and max values across the spatial dimensions.
  15. """
  16. def __init__(
  17. self,
  18. in_channels: int,
  19. skip_channels: int,
  20. out_channels: int,
  21. up_factor: int,
  22. upsample_mode: Union[UpsampleMode, str] = UpsampleMode.BILINEAR,
  23. align_corners: bool = False,
  24. ):
  25. """
  26. :params in_channels: num_channels of input feature map.
  27. :param skip_channels: num_channels of skip connection feature map.
  28. :param out_channels: num out channels after features fusion.
  29. :param up_factor: upsample scale factor of the input feature map.
  30. :param upsample_mode: see UpsampleMode for valid options.
  31. """
  32. super().__init__()
  33. self.conv_atten = nn.Sequential(
  34. ConvBNReLU(4, 2, kernel_size=3, padding=1, bias=False), ConvBNReLU(2, 1, kernel_size=3, padding=1, bias=False, use_activation=False)
  35. )
  36. self.proj_skip = nn.Identity() if skip_channels == in_channels else ConvBNReLU(skip_channels, in_channels, kernel_size=3, padding=1, bias=False)
  37. self.up_x = nn.Identity() if up_factor == 1 else make_upsample_module(scale_factor=up_factor, upsample_mode=upsample_mode, align_corners=align_corners)
  38. self.conv_out = ConvBNReLU(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
  39. def forward(self, x, skip):
  40. """
  41. :param x: input feature map to upsample before fusion.
  42. :param skip: skip connection feature map.
  43. """
  44. x = self.up_x(x)
  45. skip = self.proj_skip(skip)
  46. atten = torch.cat([*self._avg_max_spatial_reduce(x, use_concat=False), *self._avg_max_spatial_reduce(skip, use_concat=False)], dim=1)
  47. atten = self.conv_atten(atten)
  48. atten = torch.sigmoid(atten)
  49. out = x * atten + skip * (1 - atten)
  50. out = self.conv_out(out)
  51. return out
  52. @staticmethod
  53. def _avg_max_spatial_reduce(x, use_concat: bool = False):
  54. reduced = [torch.mean(x, dim=1, keepdim=True), torch.max(x, dim=1, keepdim=True)[0]]
  55. if use_concat:
  56. reduced = torch.cat(reduced, dim=1)
  57. return reduced
  58. class PPLiteSegEncoder(nn.Module):
  59. """
  60. Encoder for PPLiteSeg, include backbone followed by a context module.
  61. """
  62. def __init__(self, backbone: AbstractSTDCBackbone, projection_channels_list: List[int], context_module: nn.Module):
  63. super().__init__()
  64. self.backbone = backbone
  65. self.context_module = context_module
  66. feats_channels = backbone.get_backbone_output_number_of_channels()
  67. self.proj_convs = nn.ModuleList(
  68. [ConvBNReLU(feat_ch, proj_ch, kernel_size=3, padding=1, bias=False) for feat_ch, proj_ch in zip(feats_channels, projection_channels_list)]
  69. )
  70. self.projection_channels_list = projection_channels_list
  71. def get_output_number_of_channels(self) -> List[int]:
  72. channels_list = self.projection_channels_list
  73. if hasattr(self.context_module, "out_channels"):
  74. channels_list.append(self.context_module.out_channels)
  75. return channels_list
  76. def forward(self, x):
  77. feats = self.backbone(x)
  78. y = self.context_module(feats[-1])
  79. feats = [conv(f) for conv, f in zip(self.proj_convs, feats)]
  80. return feats + [y]
  81. class PPLiteSegDecoder(nn.Module):
  82. """
  83. PPLiteSegDecoder using UAFM blocks to fuse feature maps.
  84. """
  85. def __init__(self, encoder_channels: List[int], up_factors: List[int], out_channels: List[int], upsample_mode, align_corners: bool):
  86. super().__init__()
  87. # Make a copy of channels list, to prevent out of scope changes.
  88. encoder_channels = encoder_channels.copy()
  89. encoder_channels.reverse()
  90. in_channels = encoder_channels.pop(0)
  91. # TODO - assert argument length
  92. self.up_stages = nn.ModuleList()
  93. for skip_ch, up_factor, out_ch in zip(encoder_channels, up_factors, out_channels):
  94. self.up_stages.append(
  95. UAFM(
  96. in_channels=in_channels,
  97. skip_channels=skip_ch,
  98. out_channels=out_ch,
  99. up_factor=up_factor,
  100. upsample_mode=upsample_mode,
  101. align_corners=align_corners,
  102. )
  103. )
  104. in_channels = out_ch
  105. def forward(self, feats: List[torch.Tensor]):
  106. feats.reverse()
  107. x = feats.pop(0)
  108. for up_stage, skip in zip(self.up_stages, feats):
  109. x = up_stage(x, skip)
  110. return x
  111. class PPLiteSegBase(SegmentationModule):
  112. """
  113. The PP_LiteSeg implementation based on PaddlePaddle.
  114. The original article refers to "Juncai Peng, Yi Liu, Shiyu Tang, Yuying Hao, Lutao Chu,
  115. Guowei Chen, Zewu Wu, Zeyu Chen, Zhiliang Yu, Yuning Du, Qingqing Dang,Baohua Lai,
  116. Qiwen Liu, Xiaoguang Hu, Dianhai Yu, Yanjun Ma. PP-LiteSeg: A Superior Real-Time Semantic
  117. Segmentation Model. https://arxiv.org/abs/2204.02681".
  118. """
  119. def __init__(
  120. self,
  121. num_classes,
  122. backbone: AbstractSTDCBackbone,
  123. projection_channels_list: List[int],
  124. sppm_inter_channels: int,
  125. sppm_out_channels: int,
  126. sppm_pool_sizes: List[int],
  127. sppm_upsample_mode: Union[UpsampleMode, str],
  128. align_corners: bool,
  129. decoder_up_factors: List[int],
  130. decoder_channels: List[int],
  131. decoder_upsample_mode: Union[UpsampleMode, str],
  132. head_scale_factor: int,
  133. head_upsample_mode: Union[UpsampleMode, str],
  134. head_mid_channels: int,
  135. dropout: float,
  136. use_aux_heads: bool,
  137. aux_hidden_channels: List[int],
  138. aux_scale_factors: List[int],
  139. ):
  140. """
  141. :param backbone: Backbone nn.Module should implement the abstract class `AbstractSTDCBackbone`.
  142. :param projection_channels_list: channels list to project encoder features before fusing with the decoder
  143. stream.
  144. :param sppm_inter_channels: num channels in each sppm pooling branch.
  145. :param sppm_out_channels: The number of output channels after sppm module.
  146. :param sppm_pool_sizes: spatial output sizes of the pooled feature maps.
  147. :param sppm_upsample_mode: Upsample mode to original size after pooling.
  148. :param decoder_up_factors: list upsample factor per decoder stage.
  149. :param decoder_channels: list of num_channels per decoder stage.
  150. :param decoder_upsample_mode: upsample mode in decoder stages, see UpsampleMode for valid options.
  151. :param head_scale_factor: scale factor for final the segmentation head logits.
  152. :param head_upsample_mode: upsample mode to final prediction sizes, see UpsampleMode for valid options.
  153. :param head_mid_channels: num of hidden channels in segmentation head.
  154. :param use_aux_heads: set True when training, output extra Auxiliary feature maps from the encoder module.
  155. :param aux_hidden_channels: List of hidden channels in auxiliary segmentation heads.
  156. :param aux_scale_factors: list of uppsample factors for final auxiliary heads logits.
  157. """
  158. super().__init__(use_aux_heads=use_aux_heads)
  159. # Init Encoder
  160. backbone_out_channels = backbone.get_backbone_output_number_of_channels()
  161. assert len(backbone_out_channels) == len(projection_channels_list), (
  162. f"The length of backbone outputs ({backbone_out_channels}) should match the length of projection channels" f"({len(projection_channels_list)})."
  163. )
  164. context = SPPM(
  165. in_channels=backbone_out_channels[-1],
  166. inter_channels=sppm_inter_channels,
  167. out_channels=sppm_out_channels,
  168. pool_sizes=sppm_pool_sizes,
  169. upsample_mode=sppm_upsample_mode,
  170. align_corners=align_corners,
  171. )
  172. self.encoder = PPLiteSegEncoder(backbone=backbone, context_module=context, projection_channels_list=projection_channels_list)
  173. encoder_channels = self.encoder.get_output_number_of_channels()
  174. # Init Decoder
  175. self.decoder = PPLiteSegDecoder(
  176. encoder_channels=encoder_channels,
  177. up_factors=decoder_up_factors,
  178. out_channels=decoder_channels,
  179. upsample_mode=decoder_upsample_mode,
  180. align_corners=align_corners,
  181. )
  182. # Init Segmentation classification heads
  183. self.seg_head = nn.Sequential(
  184. SegmentationHead(in_channels=decoder_channels[-1], mid_channels=head_mid_channels, num_classes=num_classes, dropout=dropout),
  185. make_upsample_module(scale_factor=head_scale_factor, upsample_mode=head_upsample_mode, align_corners=align_corners),
  186. )
  187. # Auxiliary heads
  188. if self.use_aux_heads:
  189. encoder_out_channels = projection_channels_list
  190. self.aux_heads = nn.ModuleList(
  191. [
  192. nn.Sequential(
  193. SegmentationHead(backbone_ch, hidden_ch, num_classes, dropout=dropout),
  194. make_upsample_module(scale_factor=scale_factor, upsample_mode=head_upsample_mode, align_corners=align_corners),
  195. )
  196. for backbone_ch, hidden_ch, scale_factor in zip(encoder_out_channels, aux_hidden_channels, aux_scale_factors)
  197. ]
  198. )
  199. self.init_params()
  200. def _remove_auxiliary_heads(self):
  201. if hasattr(self, "aux_heads"):
  202. del self.aux_heads
  203. @property
  204. def backbone(self) -> nn.Module:
  205. """
  206. Support SG load backbone when training.
  207. """
  208. return self.encoder.backbone
  209. def forward(self, x):
  210. feats = self.encoder(x)
  211. if self.use_aux_heads:
  212. enc_feats = feats[:-1]
  213. x = self.decoder(feats)
  214. x = self.seg_head(x)
  215. if not self.use_aux_heads:
  216. return x
  217. aux_feats = [aux_head(feat) for feat, aux_head in zip(enc_feats, self.aux_heads)]
  218. return tuple([x] + aux_feats)
  219. def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
  220. """
  221. Custom param groups for training:
  222. - Different lr for backbone and the rest, if `multiply_head_lr` key is in `training_params`.
  223. """
  224. multiply_head_lr = get_param(training_params, "multiply_head_lr", 1)
  225. multiply_lr_params, no_multiply_params = self._separate_lr_multiply_params()
  226. param_groups = [
  227. {"named_params": no_multiply_params, "lr": lr, "name": "no_multiply_params"},
  228. {"named_params": multiply_lr_params, "lr": lr * multiply_head_lr, "name": "multiply_lr_params"},
  229. ]
  230. return param_groups
  231. def update_param_groups(self, param_groups: list, lr: float, epoch: int, iter: int, training_params: HpmStruct, total_batch: int) -> list:
  232. multiply_head_lr = get_param(training_params, "multiply_head_lr", 1)
  233. for param_group in param_groups:
  234. param_group["lr"] = lr
  235. if param_group["name"] == "multiply_lr_params":
  236. param_group["lr"] *= multiply_head_lr
  237. return param_groups
  238. def _separate_lr_multiply_params(self):
  239. """
  240. Separate backbone params from the rest.
  241. :return: iterators of groups named_parameters.
  242. """
  243. multiply_lr_params, no_multiply_params = {}, {}
  244. for name, param in self.named_parameters():
  245. if "encoder.backbone" in name:
  246. no_multiply_params[name] = param
  247. else:
  248. multiply_lr_params[name] = param
  249. return multiply_lr_params.items(), no_multiply_params.items()
  250. def prep_model_for_conversion(self, input_size: Union[tuple, list], stride_ratio: int = 32, **kwargs):
  251. if not torch_version_is_greater_or_equal(1, 11):
  252. raise RuntimeError("PPLiteSeg model ONNX export requires torch => 1.11, torch installed: " + str(torch.__version__))
  253. super().prep_model_for_conversion(input_size, **kwargs)
  254. if isinstance(self.encoder.context_module, SPPM):
  255. self.encoder.context_module.prep_model_for_conversion(input_size=input_size, stride_ratio=stride_ratio)
  256. def replace_head(self, new_num_classes: int, **kwargs):
  257. for module in self.modules():
  258. if isinstance(module, SegmentationHead):
  259. module.replace_num_classes(new_num_classes)
  260. class PPLiteSegB(PPLiteSegBase):
  261. def __init__(self, arch_params: HpmStruct):
  262. backbone = STDC2Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
  263. super().__init__(
  264. num_classes=get_param(arch_params, "num_classes"),
  265. backbone=backbone,
  266. projection_channels_list=[96, 128, 128],
  267. sppm_inter_channels=128,
  268. sppm_out_channels=128,
  269. sppm_pool_sizes=[1, 2, 4],
  270. sppm_upsample_mode="bilinear",
  271. align_corners=False,
  272. decoder_up_factors=[1, 2, 2],
  273. decoder_channels=[128, 96, 64],
  274. decoder_upsample_mode="bilinear",
  275. head_scale_factor=8,
  276. head_upsample_mode="bilinear",
  277. head_mid_channels=64,
  278. dropout=get_param(arch_params, "dropout", 0.0),
  279. use_aux_heads=get_param(arch_params, "use_aux_heads", False),
  280. aux_hidden_channels=[32, 64, 64],
  281. aux_scale_factors=[8, 16, 32],
  282. )
  283. class PPLiteSegT(PPLiteSegBase):
  284. def __init__(self, arch_params: HpmStruct):
  285. backbone = STDC1Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
  286. super().__init__(
  287. num_classes=get_param(arch_params, "num_classes"),
  288. backbone=backbone,
  289. projection_channels_list=[64, 128, 128],
  290. sppm_inter_channels=128,
  291. sppm_out_channels=128,
  292. sppm_pool_sizes=[1, 2, 4],
  293. sppm_upsample_mode="bilinear",
  294. align_corners=False,
  295. decoder_up_factors=[1, 2, 2],
  296. decoder_channels=[128, 64, 32],
  297. decoder_upsample_mode="bilinear",
  298. head_scale_factor=8,
  299. head_upsample_mode="bilinear",
  300. head_mid_channels=32,
  301. dropout=get_param(arch_params, "dropout", 0.0),
  302. use_aux_heads=get_param(arch_params, "use_aux_heads", False),
  303. aux_hidden_channels=[32, 64, 64],
  304. aux_scale_factors=[8, 16, 32],
  305. )
Discard
Tip!

Press p or to see the previous file or, n or to see the next file