1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
|
- """
- Implementation of paper: "Rethinking BiSeNet For Real-time Semantic Segmentation", https://arxiv.org/abs/2104.13188
- Based on original implementation: https://github.com/MichaelFan01/STDC-Seg, cloned 23/08/2021, commit 59ff37f
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from super_gradients.training.models import SgModule
- from super_gradients.training.utils import get_param, HpmStruct
- from super_gradients.training.utils.module_utils import ConvBNReLU
- from typing import Union, List
- from abc import ABC, abstractmethod
- # default STDC argument as paper.
- STDC_SEG_DEFAULT_ARGS = {"context_fuse_channels": 128,
- "ffm_channels": 256,
- "aux_head_channels": 64,
- "detail_head_channels": 64}
- class STDCBlock(nn.Module):
- """
- STDC building block, known as Short Term Dense Concatenate module.
- In STDC module, the kernel size of first block is 1, and the rest of them are simply set as 3.
- Args:
- steps (int): The total number of convs in this module, 1 conv 1x1 and (steps - 1) conv3x3.
- """
- def __init__(self, in_channels: int, out_channels: int, steps: int, stride: int = 1):
- super(STDCBlock, self).__init__()
- assert steps in [2, 3, 4], f"only 2, 3, 4 steps number are supported, found: {steps}"
- self.stride = stride
- self.conv_list = nn.ModuleList()
- # build first step conv 1x1.
- self.conv_list.append(ConvBNReLU(in_channels, out_channels // 2, kernel_size=1, bias=False))
- # avg pool in skip if stride = 2.
- self.skip_step1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) if stride == 2 else nn.Identity()
- in_channels = out_channels // 2
- mid_channels = in_channels
- # build rest conv3x3 layers.
- for idx in range(1, steps):
- if idx < steps - 1:
- mid_channels //= 2
- conv = ConvBNReLU(in_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=False)
- self.conv_list.append(conv)
- in_channels = mid_channels
- # add dw conv before second step for down sample if stride = 2.
- if stride == 2:
- self.conv_list[1] = nn.Sequential(
- ConvBNReLU(out_channels // 2, out_channels // 2, kernel_size=3, stride=2, padding=1,
- groups=out_channels // 2, use_activation=False, bias=False),
- self.conv_list[1])
- def forward(self, x):
- out_list = []
- # run first conv
- x = self.conv_list[0](x)
- out_list.append(self.skip_step1(x))
- for conv in self.conv_list[1:]:
- x = conv(x)
- out_list.append(x)
- out = torch.cat(out_list, dim=1)
- return out
- class AbstractSTDCBackbone(nn.Module, ABC):
- """
- All backbones for STDC segmentation models must implement this class.
- """
- def validate_backbone(self):
- assert len(self.get_backbone_output_number_of_channels()) == 3,\
- f"Backbone for STDC segmentation must output 3 feature maps," \
- f" found: {len(self.get_backbone_output_number_of_channels())}."
- @abstractmethod
- def get_backbone_output_number_of_channels(self) -> List[int]:
- """
- :return: list on stages num channels.
- """
- raise NotImplementedError()
- class STDCBackbone(AbstractSTDCBackbone):
- def __init__(self,
- block_types: list,
- ch_widths: list,
- num_blocks: list,
- stdc_steps: int = 4,
- in_channels: int = 3,
- out_down_ratios: Union[tuple, list] = (32,)):
- """
- :param block_types: list of block type for each stage, supported `conv` for ConvBNRelu with 3x3 kernel.
- :param ch_widths: list of output num of channels for each stage.
- :param num_blocks: list of the number of repeating blocks in each stage.
- :param stdc_steps: num of convs steps in each block.
- :param in_channels: num channels of the input image.
- :param out_down_ratios: down ratio of output feature maps required from the backbone,
- default (32,) for classification.
- """
- super(STDCBackbone, self).__init__()
- assert len(block_types) == len(ch_widths) == len(num_blocks),\
- f"STDC architecture configuration, block_types, ch_widths, num_blocks, must be defined for the same number" \
- f" of stages, found: {len(block_types)} for block_type, {len(ch_widths)} for ch_widths, " \
- f"{len(num_blocks)} for num_blocks"
- self.out_widths = []
- self.stages = nn.ModuleDict()
- self.out_stage_keys = []
- down_ratio = 2
- for block_type, width, blocks in zip(block_types, ch_widths, num_blocks):
- block_name = f"block_s{down_ratio}"
- self.stages[block_name] = self._make_stage(in_channels=in_channels, out_channels=width,
- block_type=block_type, num_blocks=blocks, stdc_steps=stdc_steps)
- if down_ratio in out_down_ratios:
- self.out_stage_keys.append(block_name)
- self.out_widths.append(width)
- in_channels = width
- down_ratio *= 2
- def _make_stage(self,
- in_channels: int,
- out_channels: int,
- block_type: str,
- num_blocks: int,
- stdc_steps: int = 4):
- """
- :param in_channels: input channels of stage.
- :param out_channels: output channels of stage.
- :param block_type: stage building block, supported `conv` for 3x3 ConvBNRelu, or `stdc` for STDCBlock.
- :param num_blocks: num of blocks in each stage.
- :param stdc_steps: number of conv3x3 steps in each STDC block, referred as `num blocks` in paper.
- :return: nn.Module
- """
- if block_type == "conv":
- block = ConvBNReLU
- kwargs = {"kernel_size": 3, "padding": 1, "bias": False}
- elif block_type == "stdc":
- block = STDCBlock
- kwargs = {"steps": stdc_steps}
- else:
- raise ValueError(f"Block type not supported: {block_type}, excepted: `conv` or `stdc`")
- # first block to apply stride 2.
- blocks = nn.ModuleList([
- block(in_channels, out_channels, stride=2, **kwargs)
- ])
- # build rest of blocks
- for i in range(num_blocks - 1):
- blocks.append(block(out_channels, out_channels, stride=1, **kwargs))
- return nn.Sequential(*blocks)
- def forward(self, x):
- outputs = []
- for stage_name, stage in self.stages.items():
- x = stage(x)
- if stage_name in self.out_stage_keys:
- outputs.append(x)
- return tuple(outputs)
- def get_backbone_output_number_of_channels(self) -> List[int]:
- return self.out_widths
- class STDCClassificationBase(SgModule):
- """
- Base module for classification model based on STDCs backbones
- """
- def __init__(self,
- backbone: STDCBackbone,
- num_classes: int,
- dropout: float):
- super(STDCClassificationBase, self).__init__()
- self.backbone = backbone
- last_channels = self.backbone.out_widths[-1]
- head_channels = max(1024, last_channels)
- self.conv_last = ConvBNReLU(last_channels, head_channels, 1, 1, bias=False)
- self.gap = nn.AdaptiveAvgPool2d(1)
- self.fc = nn.Linear(head_channels, head_channels, bias=False)
- self.relu = nn.ReLU(inplace=True)
- self.dropout = nn.Dropout(p=dropout)
- self.linear = nn.Linear(head_channels, num_classes, bias=False)
- self.init_params()
- def init_params(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity="relu")
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.Linear):
- nn.init.normal_(m.weight, std=0.001)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- def forward(self, x):
- out = self.backbone(x)[-1]
- # original implementation, why to use power?
- out = self.conv_last(out).pow(2)
- out = self.gap(out).flatten(1)
- out = self.fc(out)
- out = self.relu(out)
- out = self.dropout(out)
- out = self.linear(out)
- return out
- class STDCClassification(STDCClassificationBase):
- def __init__(self, arch_params: HpmStruct):
- super().__init__(backbone=get_param(arch_params, "backbone"),
- num_classes=get_param(arch_params, "num_classes"),
- dropout=get_param(arch_params, "dropout", 0.2))
- class AttentionRefinementModule(nn.Module):
- """
- AttentionRefinementModule to apply on the last two backbone stages.
- """
- def __init__(self, in_channels: int, out_channels: int):
- super(AttentionRefinementModule, self).__init__()
- self.conv_first = ConvBNReLU(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
- self.attention_block = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- ConvBNReLU(out_channels, out_channels, kernel_size=1, bias=False, use_activation=False),
- nn.Sigmoid()
- )
- def forward(self, x):
- x = self.conv_first(x)
- y = self.attention_block(x)
- return torch.mul(x, y)
- class FeatureFusionModule(nn.Module):
- """
- Fuse features from higher resolution aka, spatial feature map with features from lower resolution with high
- semantic information aka, context feature map.
- :param spatial_channels: num channels of input from spatial path.
- :param context_channels: num channels of input from context path.
- :param out_channels: num channels of feature fusion module.
- """
- def __init__(self, spatial_channels: int, context_channels: int, out_channels: int):
- super(FeatureFusionModule, self).__init__()
- self.pw_conv = ConvBNReLU(spatial_channels + context_channels, out_channels, kernel_size=1, stride=1,
- bias=False)
- # TODO - used without bias in convolutions by mistake, try to reproduce with bias=True
- self.attention_block = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- ConvBNReLU(in_channels=out_channels, out_channels=out_channels // 4, kernel_size=1, use_normalization=False,
- bias=False),
- nn.Conv2d(in_channels=out_channels // 4, out_channels=out_channels, kernel_size=1, bias=False),
- nn.Sigmoid()
- )
- def forward(self, spatial_feats, context_feats):
- feat = torch.cat([spatial_feats, context_feats], dim=1)
- feat = self.pw_conv(feat)
- atten = self.attention_block(feat)
- feat_atten = torch.mul(feat, atten)
- feat_out = feat_atten + feat
- return feat_out
- class ContextEmbeddingOnline(nn.Module):
- """
- ContextEmbedding module that use global average pooling to 1x1 to extract context information, and then upsample
- to original input size.
- """
- def __init__(self, in_channels: int, out_channels: int):
- super(ContextEmbeddingOnline, self).__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.context_embedding = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- ConvBNReLU(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
- )
- def forward(self, x):
- out_height, out_width = x.size()[2:]
- x = self.context_embedding(x)
- return F.interpolate(x, size=(out_height, out_width), mode='nearest')
- class ContextEmbeddingFixedSize(ContextEmbeddingOnline):
- """
- ContextEmbedding module that use a fixed size interpolation, supported with onnx conversion.
- Prevent slice/cast/shape operations in onnx conversion for applying interpolation.
- """
- def __init__(self, in_channels: int, out_channels: int, upsample_size: Union[list, tuple]):
- super(ContextEmbeddingFixedSize, self).__init__(in_channels, out_channels)
- self.context_embedding.add_module("upsample", nn.Upsample(scale_factor=upsample_size, mode="nearest"))
- @classmethod
- def from_context_embedding_online(cls, ce_online: ContextEmbeddingOnline, upsample_size: Union[list, tuple]):
- context = ContextEmbeddingFixedSize(in_channels=ce_online.in_channels, out_channels=ce_online.out_channels,
- upsample_size=upsample_size)
- # keep training mode state as original module
- context.train(ce_online.training)
- context.load_state_dict(ce_online.state_dict())
- return context
- def forward(self, x):
- return self.context_embedding(x)
- class ContextPath(nn.Module):
- """
- ContextPath in STDC output both the Spatial path and Context path. This module include a STDCBackbone and output
- the stage3 feature map with down_ratio = 8 as the spatial feature map, and context feature map which is a result of
- upsampling and fusion of context embedding, stage5 and stage4 after Arm modules, Which is also with same resolution
- of the spatial feature map, down_ration = 8.
- :param backbone: Backbone of type AbstractSTDCBackbone that return info about backbone output channels.
- :param fuse_channels: num channels of the fused context path.
- :param use_aux_heads: set True when training, output extra Auxiliary feature maps of the two last stages of the
- backbone.
- """
- def __init__(self,
- backbone: AbstractSTDCBackbone,
- fuse_channels: int,
- use_aux_heads: bool):
- super(ContextPath, self).__init__()
- self.use_aux_heads = use_aux_heads
- self.backbone = backbone
- # get num of channels for two last stages
- channels16, channels32 = self.backbone.get_backbone_output_number_of_channels()[-2:]
- self.context_embedding = ContextEmbeddingOnline(channels32, fuse_channels)
- self.arm32 = AttentionRefinementModule(channels32, fuse_channels)
- self.upsample32 = nn.Sequential(
- nn.Upsample(scale_factor=2, mode="nearest"),
- ConvBNReLU(fuse_channels, fuse_channels, kernel_size=3, padding=1, stride=1, bias=False)
- )
- self.arm16 = AttentionRefinementModule(channels16, fuse_channels)
- self.upsample16 = nn.Sequential(
- nn.Upsample(scale_factor=2, mode="nearest"),
- ConvBNReLU(fuse_channels, fuse_channels, kernel_size=3, padding=1, stride=1, bias=False)
- )
- def forward(self, x):
- feat8, feat16, feat32 = self.backbone(x)
- ce_feats = self.context_embedding(feat32)
- feat32_arm = self.arm32(feat32)
- feat32_arm = feat32_arm + ce_feats
- feat32_up = self.upsample32(feat32_arm)
- feat16_arm = self.arm16(feat16)
- feat16_arm = feat16_arm + feat32_up
- feat16_up = self.upsample16(feat16_arm)
- if self.use_aux_heads:
- return feat8, feat16_up, feat16, feat32
- return feat8, feat16_up
- class SegmentationHead(nn.Module):
- def __init__(self, in_channels: int, mid_channels: int, num_classes: int, dropout: float):
- super(SegmentationHead, self).__init__()
- self.seg_head = nn.Sequential(
- ConvBNReLU(in_channels, mid_channels, kernel_size=3, padding=1, stride=1, bias=False),
- nn.Dropout(dropout),
- nn.Conv2d(mid_channels, num_classes, kernel_size=1, bias=False)
- )
- def forward(self, x):
- return self.seg_head(x)
- def replace_num_classes(self, num_classes: int):
- """
- This method replace the last Conv Classification layer to output a different number of classes.
- Note that the weights of the new layers are random initiated.
- """
- old_cls_conv = self.seg_head[-1]
- self.seg_head[-1] = nn.Conv2d(old_cls_conv.in_channels, num_classes, kernel_size=1, bias=False)
- class STDCSegmentationBase(SgModule):
- """
- Base STDC Segmentation Module.
- :param backbone: Backbone of type AbstractSTDCBackbone that return info about backbone output channels.
- :param num_classes: num of dataset classes, exclude ignore label.
- :param context_fuse_channels: num of output channels in ContextPath ARM feature fusion.
- :param ffm_channels: num of output channels of Feature Fusion Module.
- :param aux_head_channels: Num of hidden channels in Auxiliary segmentation heads.
- :param detail_head_channels: Num of hidden channels in Detail segmentation heads.
- :param use_aux_heads: set True when training, attach Auxiliary and Detail heads. For compilation / inference mode
- set False.
- :param dropout: segmentation heads dropout.
- """
- def __init__(self,
- backbone: AbstractSTDCBackbone,
- num_classes: int,
- context_fuse_channels: int,
- ffm_channels: int,
- aux_head_channels: int,
- detail_head_channels: int,
- use_aux_heads: bool,
- dropout: float):
- super(STDCSegmentationBase, self).__init__()
- backbone.validate_backbone()
- self._use_aux_heads = use_aux_heads
- self.cp = ContextPath(backbone, context_fuse_channels, use_aux_heads=use_aux_heads)
- stage3_s8_channels, stage4_s16_channels, stage5_s32_channels = backbone.get_backbone_output_number_of_channels()
- self.ffm = FeatureFusionModule(spatial_channels=stage3_s8_channels, context_channels=context_fuse_channels,
- out_channels=ffm_channels)
- # Main segmentation head
- self.segmentation_head = nn.Sequential(
- SegmentationHead(ffm_channels, ffm_channels, num_classes, dropout=dropout),
- nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
- )
- if self._use_aux_heads:
- # Auxiliary heads
- self.aux_head_s16 = nn.Sequential(
- SegmentationHead(stage4_s16_channels, aux_head_channels, num_classes, dropout=dropout),
- nn.Upsample(scale_factor=16, mode="bilinear", align_corners=True)
- )
- self.aux_head_s32 = nn.Sequential(
- SegmentationHead(stage5_s32_channels, aux_head_channels, num_classes, dropout=dropout),
- nn.Upsample(scale_factor=32, mode="bilinear", align_corners=True)
- )
- # Detail head
- self.detail_head8 = nn.Sequential(
- SegmentationHead(stage3_s8_channels, detail_head_channels, 1, dropout=dropout),
- nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
- )
- self.init_params()
- def prep_model_for_conversion(self, input_size: Union[tuple, list] = None, **kwargs):
- """
- Prepare model for conversion, force use_aux_heads mode False and delete auxiliary and detail heads. Replace
- ContextEmbeddingOnline which cause compilation issues and not supported in some compilations,
- to ContextEmbeddingFixedSize.
- """
- # set to false and delete auxiliary and detail heads modules.
- self.use_aux_heads = False
- context_embedding_up_size = (input_size[-2] // 32, input_size[-1] // 32)
- self.cp.context_embedding = ContextEmbeddingFixedSize.from_context_embedding_online(self.cp.context_embedding,
- context_embedding_up_size)
- def _remove_auxiliary_and_detail_heads(self):
- attributes_to_delete = ["aux_head_s16", "aux_head_s32", "detail_head8"]
- for attr in attributes_to_delete:
- if hasattr(self, attr):
- delattr(self, attr)
- @property
- def use_aux_heads(self):
- return self._use_aux_heads
- @use_aux_heads.setter
- def use_aux_heads(self, use_aux: bool):
- """
- private setter for self._use_aux_heads, called every time an assignment to self._use_aux_heads is applied.
- if use_aux is False, `_remove_auxiliary_and_detail_heads` is called to delete auxiliary and detail heads.
- if use_aux is True, and self._use_aux_heads was already set to False a ValueError is raised, recreating
- aux and detail heads outside init method is not allowed, and the module should be recreated.
- """
- if use_aux is True and self._use_aux_heads is False:
- raise ValueError("Cant turn use_aux_heads from False to True, you should initiate the module again with"
- " `use_aux_heads=True`")
- if not use_aux:
- self._remove_auxiliary_and_detail_heads()
- self.cp.use_aux_heads = use_aux
- self._use_aux_heads = use_aux
- @property
- def backbone(self):
- """
- For Trainer load_backbone compatibility.
- """
- return self.cp.backbone
- def init_params(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity="relu")
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.Linear):
- nn.init.normal_(m.weight, std=0.001)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- def forward(self, x):
- cp_outs = self.cp(x)
- feat8, feat_cp8 = cp_outs[0], cp_outs[1]
- # fuse stage 3 with result of context path after ARM modules.
- feat_out = self.ffm(spatial_feats=feat8, context_feats=feat_cp8)
- feat_out = self.segmentation_head(feat_out)
- if not self.use_aux_heads:
- return feat_out
- feat16, feat32 = cp_outs[2], cp_outs[3]
- detail_out8 = self.detail_head8(feat8)
- aux_out_s16 = self.aux_head_s16(feat16)
- aux_out_s32 = self.aux_head_s32(feat32)
- return feat_out, aux_out_s32, aux_out_s16, detail_out8
- def replace_head(self, new_num_classes: int, **kwargs):
- ffm_channels = self.ffm.attention_block[-2].out_channels
- dropout = self.segmentation_head[0].seg_head[1].p
- # Output layer's replacement- first modules in the sequences are the SegmentationHead modules.
- self.segmentation_head[0] = SegmentationHead(ffm_channels, ffm_channels, new_num_classes, dropout=dropout)
- if self.use_aux_heads:
- stage3_s8_channels, stage4_s16_channels, stage5_s32_channels = self.backbone.get_backbone_output_number_of_channels()
- aux_head_channels = self.aux_head_s16[0].seg_head[-1].in_channels
- detail_head_channels = self.detail_head8[0].seg_head[-1].in_channels
- self.aux_head_s16[0] = SegmentationHead(stage4_s16_channels, aux_head_channels, new_num_classes, dropout=dropout)
- self.aux_head_s32[0] = SegmentationHead(stage5_s32_channels, aux_head_channels, new_num_classes, dropout=dropout)
- # Detail head
- self.detail_head8[0] = SegmentationHead(stage3_s8_channels, detail_head_channels, 1, dropout=dropout)
- def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
- """
- Custom param groups for STDC training:
- - Different lr for context path and heads, if `multiply_head_lr` key is in `training_params`.
- - Add extra Detail loss params to optimizer.
- """
- extra_train_params = training_params.loss.get_train_named_params() if hasattr(training_params.loss, "get_train_named_params") else None
- multiply_head_lr = get_param(training_params, "multiply_head_lr", 1)
- multiply_lr_params, no_multiply_params = self._separate_lr_multiply_params()
- param_groups = [{"named_params": no_multiply_params, "lr": lr, "name": "no_multiply_params"},
- {"named_params": multiply_lr_params, "lr": lr * multiply_head_lr, "name": "multiply_lr_params"}]
- if extra_train_params is not None:
- param_groups.append({"named_params": extra_train_params, "lr": lr, "weight_decay": 0.,
- "name": "detail_params"})
- return param_groups
- def update_param_groups(self, param_groups: list, lr: float, epoch: int, iter: int, training_params: HpmStruct,
- total_batch: int) -> list:
- multiply_head_lr = get_param(training_params, "multiply_head_lr", 1)
- for param_group in param_groups:
- param_group['lr'] = lr
- if param_group["name"] == "multiply_lr_params":
- param_group['lr'] *= multiply_head_lr
- return param_groups
- def _separate_lr_multiply_params(self):
- """
- Separate ContextPath params from the rest.
- :return: iterators of groups named_parameters.
- """
- multiply_lr_params, no_multiply_params = {}, {}
- for name, param in self.named_parameters():
- if "cp." in name:
- no_multiply_params[name] = param
- else:
- multiply_lr_params[name] = param
- return multiply_lr_params.items(), no_multiply_params.items()
- class CustomSTDCSegmentation(STDCSegmentationBase):
- """
- Fully customized STDC Segmentation factory module.
- """
- def __init__(self, arch_params: HpmStruct):
- super().__init__(backbone=get_param(arch_params, "backbone"),
- num_classes=get_param(arch_params, "num_classes"),
- context_fuse_channels=get_param(arch_params, "context_fuse_channels", 128),
- ffm_channels=get_param(arch_params, "ffm_channels", 256),
- aux_head_channels=get_param(arch_params, "aux_head_channels", 64),
- detail_head_channels=get_param(arch_params, "detail_head_channels", 64),
- use_aux_heads=get_param(arch_params, "use_aux_heads", True),
- dropout=get_param(arch_params, "dropout", 0.2))
- class STDC1Backbone(STDCBackbone):
- def __init__(self, in_channels: int = 3, out_down_ratios: Union[tuple, list] = (32,)):
- super().__init__(block_types=["conv", "conv", "stdc", "stdc", "stdc"],
- ch_widths=[32, 64, 256, 512, 1024], num_blocks=[1, 1, 2, 2, 2], stdc_steps=4,
- in_channels=in_channels, out_down_ratios=out_down_ratios)
- class STDC2Backbone(STDCBackbone):
- def __init__(self, in_channels: int = 3, out_down_ratios: Union[tuple, list] = (32,)):
- super().__init__(block_types=["conv", "conv", "stdc", "stdc", "stdc"],
- ch_widths=[32, 64, 256, 512, 1024], num_blocks=[1, 1, 4, 5, 3], stdc_steps=4,
- in_channels=in_channels, out_down_ratios=out_down_ratios)
- class STDC1Classification(STDCClassification):
- def __init__(self, arch_params: HpmStruct):
- backbone = STDC1Backbone(in_channels=get_param(arch_params, "input_channels", 3), out_down_ratios=(32,))
- arch_params.override(**{"backbone": backbone})
- super().__init__(arch_params)
- class STDC2Classification(STDCClassification):
- def __init__(self, arch_params: HpmStruct):
- backbone = STDC2Backbone(in_channels=get_param(arch_params, "input_channels", 3), out_down_ratios=(32,))
- arch_params.override(**{"backbone": backbone})
- super().__init__(arch_params)
- class STDC1Seg(CustomSTDCSegmentation):
- def __init__(self, arch_params: HpmStruct):
- backbone = STDC1Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
- custom_params = {"backbone": backbone, **STDC_SEG_DEFAULT_ARGS}
- arch_params.override(**custom_params)
- super().__init__(arch_params)
- class STDC2Seg(CustomSTDCSegmentation):
- def __init__(self, arch_params: HpmStruct):
- backbone = STDC2Backbone(in_channels=get_param(arch_params, "in_channels", 3), out_down_ratios=[8, 16, 32])
- custom_params = {"backbone": backbone, **STDC_SEG_DEFAULT_ARGS}
- arch_params.override(**custom_params)
- super().__init__(arch_params)
|