Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
  1. """
  2. Shelfnet
  3. paper: https://arxiv.org/abs/1811.11254
  4. based on: https://github.com/juntang-zhuang/ShelfNet
  5. """
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from super_gradients.training.models.sg_module import SgModule
  10. from super_gradients.training.utils import HpmStruct
  11. from super_gradients.training.models.classification_models.resnet import BasicBlock, ResNet, Bottleneck
  12. class FCNHead(nn.Module):
  13. def __init__(self, in_channels, out_channels):
  14. super().__init__()
  15. inter_channels = in_channels // 4
  16. self.fcn = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
  17. nn.BatchNorm2d(inter_channels),
  18. nn.ReLU(),
  19. nn.Dropout2d(0.1, False),
  20. nn.Conv2d(inter_channels, out_channels, 1))
  21. def forward(self, x):
  22. return self.fcn(x)
  23. class ShelfBlock(nn.Module):
  24. def __init__(self, in_planes: int, planes: int, stride: int = 1, dropout: float = 0.25):
  25. """
  26. S-Block implementation from the ShelfNet paper
  27. :param in_planes: input planes
  28. :param planes: output planes
  29. :param stride: convolution stride
  30. :param dropout: dropout percentage
  31. """
  32. super().__init__()
  33. if in_planes != planes:
  34. self.conv0 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=True)
  35. self.relu0 = nn.ReLU(inplace=True)
  36. self.in_planes = in_planes
  37. self.planes = planes
  38. self.conv1 = nn.Conv2d(self.planes, self.planes, kernel_size=3, stride=stride, padding=1, bias=True)
  39. self.bn1 = nn.BatchNorm2d(self.planes)
  40. self.relu1 = nn.ReLU(inplace=True)
  41. self.dropout = nn.Dropout2d(p=dropout)
  42. self.bn2 = nn.BatchNorm2d(self.planes)
  43. self.relu2 = nn.ReLU(inplace=True)
  44. def forward(self, x):
  45. if self.in_planes != self.planes:
  46. x = self.conv0(x)
  47. x = self.relu0(x)
  48. out = self.conv1(x)
  49. out = self.bn1(out)
  50. out = self.relu1(out)
  51. out = self.dropout(out)
  52. out = self.conv1(out)
  53. out = self.bn2(out)
  54. out = out + x
  55. return self.relu2(out)
  56. class ShelfResNetBackBone(ResNet):
  57. """
  58. ShelfResNetBackBone - A class that Inherits from the original ResNet class and manipulates the forward pass,
  59. to create a backbone for the ShelfNet architecture
  60. """
  61. def __init__(self, block, num_blocks, num_classes=10, width_mult=1):
  62. super().__init__(block=block, num_blocks=num_blocks, num_classes=num_classes, width_mult=width_mult,
  63. backbone_mode=True)
  64. def forward(self, x):
  65. out = F.relu(self.bn1(self.conv1(x)))
  66. out = self.maxpool(out)
  67. feat4 = self.layer1(out) # 1/4
  68. feat8 = self.layer2(feat4) # 1/8
  69. feat16 = self.layer3(feat8) # 1/16
  70. feat32 = self.layer4(feat16) # 1/32
  71. return feat4, feat8, feat16, feat32
  72. class ShelfResNetBackBone18(ShelfResNetBackBone):
  73. def __init__(self, num_classes: int):
  74. super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
  75. class ShelfResNetBackBone34(ShelfResNetBackBone):
  76. def __init__(self, num_classes: int):
  77. super().__init__(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
  78. class ShelfResNetBackBone503343(ShelfResNetBackBone):
  79. def __init__(self, num_classes: int):
  80. super().__init__(Bottleneck, [3, 3, 4, 3], num_classes=num_classes)
  81. class ShelfResNetBackBone50(ShelfResNetBackBone):
  82. def __init__(self, num_classes: int):
  83. super().__init__(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
  84. class ShelfResNetBackBone101(ShelfResNetBackBone):
  85. def __init__(self, num_classes: int):
  86. super().__init__(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
  87. class ShelfNetModuleBase(SgModule):
  88. """
  89. ShelfNetModuleBase - Base class for the different Modules of the ShelfNet Architecture
  90. """
  91. def __init__(self):
  92. super().__init__()
  93. def forward(self, x):
  94. raise NotImplementedError
  95. def get_params(self):
  96. wd_params, nowd_params = [], []
  97. for name, module in self.named_modules():
  98. if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
  99. wd_params.append(module.weight)
  100. if module.bias is not None:
  101. nowd_params.append(module.bias)
  102. elif isinstance(module, nn.BatchNorm2d):
  103. nowd_params += list(module.parameters())
  104. return wd_params, nowd_params
  105. class ConvBNReLU(ShelfNetModuleBase):
  106. def __init__(self, in_chan: int, out_chan: int, ks: int = 3, stride: int = 1, padding: int = 1):
  107. super(ConvBNReLU, self).__init__()
  108. self.conv = nn.Conv2d(in_chan,
  109. out_chan,
  110. kernel_size=ks,
  111. stride=stride,
  112. padding=padding,
  113. bias=False)
  114. self.bn = nn.BatchNorm2d(out_chan)
  115. self.init_weight()
  116. def forward(self, x):
  117. x = self.conv(x)
  118. x = self.bn(x)
  119. x = F.relu(x)
  120. return x
  121. def init_weight(self):
  122. for ly in self.children():
  123. if isinstance(ly, nn.Conv2d):
  124. nn.init.kaiming_normal_(ly.weight, a=1)
  125. if ly.bias is not None:
  126. nn.init.constant_(ly.bias, 0)
  127. class DecoderBase(ShelfNetModuleBase):
  128. def __init__(self, planes: int, layers: int, kernel: int = 3, block=ShelfBlock):
  129. super().__init__()
  130. self.planes = planes
  131. self.layers = layers
  132. self.kernel = kernel
  133. self.padding = int((kernel - 1) / 2)
  134. self.inconv = block(planes, planes)
  135. # CREATE MODULE FOR BOTTOM BLOCK
  136. self.bottom = block(planes * (2 ** (layers - 1)), planes * (2 ** (layers - 1)))
  137. # CREATE MODULE LIST FOR UP BRANCH
  138. self.up_conv_list = nn.ModuleList()
  139. self.up_dense_list = nn.ModuleList()
  140. def forward(self, x):
  141. raise NotImplementedError
  142. class DecoderHW(DecoderBase):
  143. """
  144. DecoderHW - The Decoder for the Heavy-Weight ShelfNet Architecture
  145. """
  146. def __init__(self, planes, layers, block=ShelfBlock, *args, **kwargs):
  147. super().__init__(planes=planes, layers=layers, block=block, *args, **kwargs)
  148. for i in range(0, layers - 1):
  149. self.up_conv_list.append(
  150. nn.ConvTranspose2d(planes * 2 ** (layers - 1 - i), planes * 2 ** max(0, layers - i - 2), kernel_size=3,
  151. stride=2, padding=1, output_padding=1, bias=True))
  152. self.up_dense_list.append(block(planes * 2 ** max(0, layers - i - 2), planes * 2 ** max(0, layers - i - 2)))
  153. def forward(self, x):
  154. # BOTTOM BRANCH
  155. out = self.bottom(x[-1])
  156. bottom = out
  157. # UP BRANCH
  158. up_out = []
  159. up_out.append(bottom)
  160. for j in range(0, self.layers - 1):
  161. out = self.up_conv_list[j](out) + x[self.layers - j - 2]
  162. out = self.up_dense_list[j](out)
  163. up_out.append(out)
  164. return up_out
  165. class DecoderLW(DecoderBase):
  166. """
  167. DecoderLW - The Decoder for the Light-Weight ShelfNet Architecture
  168. """
  169. def __init__(self, planes, layers, block=ShelfBlock, *args, **kwargs):
  170. super().__init__(planes=planes, layers=layers, block=block, *args, **kwargs)
  171. for i in range(0, layers - 1):
  172. self.up_conv_list.append(
  173. AttentionRefinementModule(planes * 2 ** (layers - 1 - i), planes * 2 ** max(0, layers - i - 2)))
  174. self.up_dense_list.append(
  175. ConvBNReLU(in_chan=planes * 2 ** max(0, layers - i - 2), out_chan=planes * 2 ** max(0, layers - i - 2),
  176. ks=3, stride=1))
  177. def forward(self, x):
  178. # BOTTOM BRANCH
  179. out = self.bottom(x[-1])
  180. bottom = out
  181. # UP BRANCH
  182. up_out = []
  183. up_out.append(bottom)
  184. for j in range(0, self.layers - 1):
  185. out = self.up_conv_list[j](out)
  186. out_interpolate = F.interpolate(out, (out.size(2) * 2, out.size(3) * 2), mode='nearest')
  187. out = out_interpolate + x[self.layers - j - 2]
  188. out = self.up_dense_list[j](out)
  189. up_out.append(out)
  190. return up_out
  191. class AttentionRefinementModule(nn.Module):
  192. def __init__(self, in_chan, out_chan):
  193. super().__init__()
  194. self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
  195. self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
  196. self.bn_atten = nn.BatchNorm2d(out_chan)
  197. self.sigmoid_atten = nn.Sigmoid()
  198. self.init_weight()
  199. def forward(self, x):
  200. feat = self.conv(x)
  201. atten = F.avg_pool2d(feat, feat.size()[2:])
  202. atten = self.conv_atten(atten)
  203. atten = self.bn_atten(atten)
  204. atten = self.sigmoid_atten(atten)
  205. out = torch.mul(feat, atten)
  206. return out
  207. def init_weight(self):
  208. for ly in self.children():
  209. if isinstance(ly, nn.Conv2d):
  210. nn.init.kaiming_normal_(ly.weight, a=1)
  211. if ly.bias is not None:
  212. nn.init.constant_(ly.bias, 0)
  213. class LadderBlockBase(ShelfNetModuleBase):
  214. def __init__(self, planes: int, layers: int, kernel: int = 3, block=ShelfBlock):
  215. super().__init__()
  216. self.planes = planes
  217. self.layers = layers
  218. self.kernel = kernel
  219. self.padding = int((kernel - 1) / 2)
  220. self.inconv = block(planes, planes)
  221. # CREATE MODULE LIST FOR DOWN BRANCH
  222. self.down_module_list = nn.ModuleList()
  223. for i in range(0, layers - 1):
  224. self.down_module_list.append(block(planes * (2 ** i), planes * (2 ** i)))
  225. # USE STRIDED CONV INSTEAD OF POOLING
  226. self.down_conv_list = nn.ModuleList()
  227. for i in range(0, layers - 1):
  228. self.down_conv_list.append(
  229. nn.Conv2d(planes * 2 ** i, planes * 2 ** (i + 1), stride=2, kernel_size=kernel, padding=self.padding))
  230. # CREATE MODULE FOR BOTTOM BLOCK
  231. self.bottom = block(planes * (2 ** (layers - 1)), planes * (2 ** (layers - 1)))
  232. # CREATE MODULE LIST FOR UP BRANCH
  233. self.up_conv_list = nn.ModuleList()
  234. self.up_dense_list = nn.ModuleList()
  235. def forward(self, x):
  236. raise NotImplementedError
  237. class LadderBlockHW(LadderBlockBase):
  238. """
  239. LadderBlockHW - LadderBlock for the Heavy-Weight ShelfNet Architecture
  240. """
  241. def __init__(self, planes, layers, block=ShelfBlock, *args, **kwargs):
  242. super().__init__(planes=planes, layers=layers, block=block, *args, **kwargs)
  243. for i in range(0, layers - 1):
  244. self.up_conv_list.append(nn.ConvTranspose2d(planes * 2 ** (layers - i - 1),
  245. planes * 2 ** max(0, layers - i - 2),
  246. kernel_size=3,
  247. stride=2,
  248. padding=1,
  249. output_padding=1,
  250. bias=True))
  251. self.up_dense_list.append(block(planes * 2 ** max(0, layers - i - 2), planes * 2 ** max(0, layers - i - 2)))
  252. def forward(self, x):
  253. out = self.inconv(x[-1])
  254. down_out = []
  255. # down branch
  256. for i in range(0, self.layers - 1):
  257. out = out + x[-i - 1]
  258. out = self.down_module_list[i](out)
  259. down_out.append(out)
  260. out = self.down_conv_list[i](out)
  261. out = F.relu(out)
  262. # bottom branch
  263. out = self.bottom(out)
  264. bottom = out
  265. # up branch
  266. up_out = []
  267. up_out.append(bottom)
  268. for j in range(0, self.layers - 1):
  269. out = self.up_conv_list[j](out) + down_out[self.layers - j - 2]
  270. out = self.up_dense_list[j](out)
  271. up_out.append(out)
  272. return up_out
  273. class LadderBlockLW(LadderBlockBase):
  274. """
  275. LadderBlockLW - LadderBlock for the Light-Weight ShelfNet Architecture
  276. """
  277. def __init__(self, planes, layers, block=ShelfBlock, *args, **kwargs):
  278. super().__init__(planes=planes, layers=layers, block=block, *args, **kwargs)
  279. for i in range(0, layers - 1):
  280. self.up_conv_list.append(
  281. AttentionRefinementModule(planes * 2 ** (layers - 1 - i), planes * 2 ** max(0, layers - i - 2))
  282. )
  283. self.up_dense_list.append(
  284. ConvBNReLU(in_chan=planes * 2 ** max(0, layers - i - 2), out_chan=planes * 2 ** max(0, layers - i - 2),
  285. ks=3, stride=1))
  286. def forward(self, x):
  287. out = self.inconv(x[-1])
  288. down_out = []
  289. # DOWN BRANCH
  290. for i in range(0, self.layers - 1):
  291. out = out + x[-i - 1]
  292. out = self.down_module_list[i](out)
  293. down_out.append(out)
  294. out = self.down_conv_list[i](out)
  295. out = F.relu(out)
  296. # BOTTOM BRANCH
  297. out = self.bottom(out)
  298. bottom = out
  299. # UP BRANCH
  300. up_out = []
  301. up_out.append(bottom)
  302. for j in range(0, self.layers - 1):
  303. out = self.up_conv_list[j](out)
  304. out = F.interpolate(out, (out.size(2) * 2, out.size(3) * 2), mode='nearest') + down_out[self.layers - j - 2]
  305. out = self.up_dense_list[j](out)
  306. up_out.append(out)
  307. return up_out
  308. class NetOutput(ShelfNetModuleBase):
  309. def __init__(self, in_chan: int, mid_chan: int, num_classes: int):
  310. super(NetOutput, self).__init__()
  311. self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
  312. self.conv_out = nn.Conv2d(mid_chan, num_classes, kernel_size=3, bias=False,
  313. padding=1)
  314. self.init_weight()
  315. def forward(self, x):
  316. x = self.conv(x)
  317. x = self.conv_out(x)
  318. return x
  319. def init_weight(self):
  320. for ly in self.children():
  321. if isinstance(ly, nn.Conv2d):
  322. nn.init.kaiming_normal_(ly.weight, a=1)
  323. if ly.bias is not None:
  324. nn.init.constant_(ly.bias, 0)
  325. class ShelfNetBase(ShelfNetModuleBase):
  326. """
  327. ShelfNetBase - ShelfNet Base Generic Architecture
  328. """
  329. def __init__(self, backbone: ShelfResNetBackBone, planes: int, layers: int, num_classes: int = 21,
  330. image_size: int = 512,
  331. net_output_mid_channels_num: int = 64, arch_params: HpmStruct = None):
  332. self.num_classes = arch_params.num_classes if (arch_params and hasattr(arch_params, 'num_classes')) else num_classes
  333. self.image_size = arch_params.image_size if (arch_params and hasattr(arch_params, 'image_size')) else image_size
  334. super().__init__()
  335. self.net_output_mid_channels_num = net_output_mid_channels_num
  336. self.backbone = backbone(self.num_classes)
  337. self.layers = layers
  338. self.planes = planes
  339. # INITIALIZE WITH AUXILARY HEAD OUTPUTS ONN -> TURN IT OFF TO RUN A FORWARD PASS WITHOUT THE AUXILARY HEADS
  340. self.auxilary_head_outputs = True
  341. # DECODER AND LADDER SHOULD BE IMPLEMENTED BY THE INHERITING CLASS
  342. self.decoder = None
  343. self.ladder = None
  344. # BUILD THE CONV_OUT LIST BASED ON THE AMOUNT OF LAYERS IN THE SHELFNET
  345. self.conv_out_list = torch.nn.ModuleList()
  346. def forward(self, x):
  347. raise NotImplementedError
  348. def update_param_groups(self, param_groups: list, lr: float, epoch: int, iter: int, training_params: HpmStruct,
  349. total_batch: int) \
  350. -> list:
  351. """
  352. update_optimizer_for_param_groups - Updates the specific parameters with different LR
  353. """
  354. # LEARNING RATE FOR THE BACKBONE IS lr
  355. param_groups[0]['lr'] = lr
  356. for i in range(1, len(param_groups)):
  357. # LEARNING RATE FOR OTHER SHELFNET PARAMS IS lr * 10
  358. param_groups[i]['lr'] = lr * 10
  359. return param_groups
  360. class ShelfNetHW(ShelfNetBase):
  361. """
  362. ShelfNetHW - Heavy-Weight Version of ShelfNet
  363. """
  364. def __init__(self, *args, **kwargs):
  365. super().__init__(*args, **kwargs)
  366. self.ladder = LadderBlockHW(planes=self.net_output_mid_channels_num, layers=self.layers)
  367. self.decoder = DecoderHW(planes=self.net_output_mid_channels_num, layers=self.layers)
  368. self.se_layer = nn.Linear(self.net_output_mid_channels_num * 2 ** 3, self.num_classes)
  369. self.aux_head = FCNHead(1024, self.num_classes)
  370. self.final = nn.Conv2d(self.net_output_mid_channels_num, self.num_classes, 1)
  371. # THE MID CHANNELS NUMBER OF THE NET OUTPUT BLOCK
  372. net_out_planes = self.planes
  373. mid_channels_num = self.net_output_mid_channels_num
  374. # INITIALIZE THE conv_out_list
  375. for i in range(self.layers):
  376. self.conv_out_list.append(
  377. ConvBNReLU(in_chan=net_out_planes, out_chan=mid_channels_num, ks=1, padding=0))
  378. mid_channels_num *= 2
  379. net_out_planes *= 2
  380. def forward(self, x):
  381. image_size = x.size()[2:]
  382. backbone_features_list = list(self.backbone(x))
  383. conv_bn_relu_results_list = []
  384. for feature, conv_bn_relu in zip(backbone_features_list, self.conv_out_list):
  385. out = conv_bn_relu(feature)
  386. conv_bn_relu_results_list.append(out)
  387. decoder_out_list = self.decoder(conv_bn_relu_results_list)
  388. ladder_out_list = self.ladder(decoder_out_list)
  389. preds = [self.final(ladder_out_list[-1])]
  390. # SE_LOSS ENCODING
  391. enc = F.max_pool2d(ladder_out_list[0], kernel_size=ladder_out_list[0].size()[2:])
  392. enc = torch.squeeze(enc, -1)
  393. enc = torch.squeeze(enc, -1)
  394. se = self.se_layer(enc)
  395. preds.append(se)
  396. # UP SAMPLING THE TOP LAYER FOR PREDICTION
  397. preds[0] = F.interpolate(preds[0], image_size, mode='bilinear', align_corners=True)
  398. # AUXILARY HEAD OUTPUT (ONLY RELEVANT FOR LOSS CALCULATION) - USE self.auxilary_head_outputs=FALSE FOR INFERENCE
  399. if self.auxilary_head_outputs or self.training:
  400. aux_out = self.aux_head(backbone_features_list[2])
  401. aux_out = F.interpolate(aux_out, image_size, mode='bilinear', align_corners=True)
  402. preds.append(aux_out)
  403. return tuple(preds)
  404. else:
  405. return preds[0]
  406. def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
  407. """
  408. initialize_optimizer_for_model_param_groups - Initializes the weights of the optimizer
  409. Initializes the Backbone, the Output and the Auxilary Head
  410. differently
  411. :param optimizer_cls: The nn.optim (optimizer class) to initialize
  412. :param lr: lr to set for the optimizer
  413. :param training_params:
  414. :return: list of dictionaries with named params and optimizer attributes
  415. """
  416. # OPTIMIZER PARAMETER GROUPS
  417. params_list = []
  418. # OPTIMIZE BACKBONE USING DIFFERENT LR
  419. params_list.append({'named_params': self.backbone.named_parameters(), 'lr': lr})
  420. # OPTIMIZE MAIN SHELFNET ARCHITECTURE LAYERS
  421. params_list.append({'named_params': list(self.ladder.named_parameters()) + list(
  422. self.decoder.named_parameters()) + list(self.se_layer.named_parameters()) + list(
  423. self.conv_out_list.named_parameters()) + list(self.final.named_parameters()) + list(
  424. self.aux_head.named_parameters()), 'lr': lr * 10})
  425. return params_list
  426. class ShelfNetLW(ShelfNetBase):
  427. """
  428. ShelfNetLW - Light-Weight Implementation for ShelfNet
  429. """
  430. def __init__(self, *args, **kwargs):
  431. super().__init__(*args, **kwargs)
  432. self.net_output_list = nn.ModuleList()
  433. self.ladder = LadderBlockLW(planes=self.planes, layers=self.layers)
  434. self.decoder = DecoderLW(planes=self.planes, layers=self.layers)
  435. def forward(self, x):
  436. H, W = x.size()[2:]
  437. # SHELFNET LW ARCHITECTURE USES ONLY LAST 3 PARTIAL OUTPUTs OF THE BACKBONE'S 4 OUTPUT LAYERS
  438. backbone_features_tuple = self.backbone(x)[1:]
  439. if isinstance(self, ShelfNet18_LW):
  440. # FOR SHELFNET18 USE 1x1 CONVS AFTER THE BACKBONE'S FORWARD PASS TO MANIPULATE THE CHANNELS FOR THE DECODER
  441. conv_bn_relu_results_list = []
  442. for feature, conv_bn_relu in zip(backbone_features_tuple, self.conv_out_list):
  443. out = conv_bn_relu(feature)
  444. conv_bn_relu_results_list.append(out)
  445. else:
  446. # FOR SHELFNET34 THE CHANNELS ARE ALREADY ALIGNED
  447. conv_bn_relu_results_list = list(backbone_features_tuple)
  448. decoder_out_list = self.decoder(conv_bn_relu_results_list)
  449. ladder_out_list = self.ladder(decoder_out_list)
  450. # GET THE LAST ELEMENTS OF THE LADDER_BLOCK BASED ON THE AMOUNT OF SHELVES IN THE ARCHITECTURE AND REVERSE LIST
  451. feat_cp_list = list(reversed(ladder_out_list[(-1 * self.layers):]))
  452. feat_out = self.net_output_list[0](feat_cp_list[0])
  453. feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
  454. if self.auxilary_head_outputs or self.training:
  455. features_out_list = [feat_out]
  456. for conv_output_layer, feat_cp in zip(self.net_output_list[1:], feat_cp_list[1:]):
  457. feat_out_res = conv_output_layer(feat_cp)
  458. feat_out_res = F.interpolate(feat_out_res, (H, W), mode='bilinear', align_corners=True)
  459. features_out_list.append(feat_out_res)
  460. return tuple(features_out_list)
  461. else:
  462. # THIS DOES NOT CALCULATE THE AUXILARY HEADS THAT ARE CRITICAL FOR THE LOSS (USED MAINLY FOR INFERENCE)
  463. return feat_out
  464. def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
  465. """
  466. initialize_optimizer_for_model_param_groups - Initializes the optimizer group params, with 10x learning rate
  467. for all but the backbone
  468. :param lr: lr to set for the backbone
  469. :param training_params:
  470. :return: list of dictionaries with named params and optimizer attributes
  471. """
  472. # OPTIMIZER PARAMETER GROUPS
  473. params_list = []
  474. # OPTIMIZE BACKBONE USING DIFFERENT LR
  475. params_list.append({'named_params': self.backbone.named_parameters(), 'lr': lr})
  476. # OPTIMIZE MAIN SHELFNET ARCHITECTURE LAYERS
  477. params_list.append({'named_params': list(self.ladder.named_parameters()) + list(
  478. self.decoder.named_parameters()) + list(
  479. self.conv_out_list.named_parameters()), 'lr': lr * 10})
  480. return params_list
  481. class ShelfNet18_LW(ShelfNetLW):
  482. def __init__(self, *args, **kwargs):
  483. super().__init__(backbone=ShelfResNetBackBone18, planes=64, layers=3, *args, **kwargs)
  484. # INITIALIZE THE net_output_list AND THE conv_out LIST
  485. out_planes = self.planes
  486. for i in range(self.layers):
  487. # THE MID CHANNELS NUMBER OF THE NET OUTPUT BLOCK
  488. mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
  489. self.net_output_list.append(
  490. NetOutput(out_planes, mid_channels_num, self.num_classes))
  491. self.conv_out_list.append(
  492. ConvBNReLU(out_planes * 2, out_planes, ks=1, stride=1, padding=0)
  493. )
  494. out_planes *= 2
  495. class ShelfNet34_LW(ShelfNetLW):
  496. def __init__(self, *args, **kwargs):
  497. super().__init__(backbone=ShelfResNetBackBone34, planes=128, layers=3, *args, **kwargs)
  498. # INITIALIZE THE net_output_list
  499. net_out_planes = self.planes
  500. for i in range(self.layers):
  501. # IF IT'S THE FIRST LAYER THAN THE MID-CHANNELS NUM IS ACTUALLY self.planes
  502. mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
  503. self.net_output_list.append(
  504. NetOutput(net_out_planes, mid_channels_num, self.num_classes))
  505. net_out_planes *= 2
  506. class ShelfNet503343(ShelfNetHW):
  507. def __init__(self, *args, **kwargs):
  508. super().__init__(backbone=ShelfResNetBackBone503343, planes=256, layers=4, *args, **kwargs)
  509. class ShelfNet50(ShelfNetHW):
  510. def __init__(self, *args, **kwargs):
  511. super().__init__(backbone=ShelfResNetBackBone50, planes=256, layers=4, *args, **kwargs)
  512. class ShelfNet101(ShelfNetHW):
  513. def __init__(self, *args, **kwargs):
  514. super().__init__(backbone=ShelfResNetBackBone101, planes=256, layers=4, *args, **kwargs)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file