|
@@ -402,10 +402,10 @@ class LadderBlockLW(LadderBlockBase):
|
|
|
|
|
|
|
|
|
|
class NetOutput(ShelfNetModuleBase):
|
|
class NetOutput(ShelfNetModuleBase):
|
|
- def __init__(self, in_chan: int, mid_chan: int, classes_num: int):
|
|
|
|
|
|
+ def __init__(self, in_chan: int, mid_chan: int, num_classes: int):
|
|
super(NetOutput, self).__init__()
|
|
super(NetOutput, self).__init__()
|
|
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
|
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
|
- self.conv_out = nn.Conv2d(mid_chan, classes_num, kernel_size=3, bias=False,
|
|
|
|
|
|
+ self.conv_out = nn.Conv2d(mid_chan, num_classes, kernel_size=3, bias=False,
|
|
padding=1)
|
|
padding=1)
|
|
self.init_weight()
|
|
self.init_weight()
|
|
|
|
|
|
@@ -427,15 +427,15 @@ class ShelfNetBase(ShelfNetModuleBase):
|
|
ShelfNetBase - ShelfNet Base Generic Architecture
|
|
ShelfNetBase - ShelfNet Base Generic Architecture
|
|
"""
|
|
"""
|
|
|
|
|
|
- def __init__(self, backbone: ShelfResNetBackBone, planes: int, layers: int, classes_num: int = 21,
|
|
|
|
|
|
+ def __init__(self, backbone: ShelfResNetBackBone, planes: int, layers: int, num_classes: int = 21,
|
|
image_size: int = 512,
|
|
image_size: int = 512,
|
|
net_output_mid_channels_num: int = 64, arch_params: HpmStruct = None):
|
|
net_output_mid_channels_num: int = 64, arch_params: HpmStruct = None):
|
|
- self.classes_num = arch_params.num_classes if (arch_params and hasattr(arch_params, 'num_classes')) else classes_num
|
|
|
|
|
|
+ self.num_classes = arch_params.num_classes if (arch_params and hasattr(arch_params, 'num_classes')) else num_classes
|
|
self.image_size = arch_params.image_size if (arch_params and hasattr(arch_params, 'image_size')) else image_size
|
|
self.image_size = arch_params.image_size if (arch_params and hasattr(arch_params, 'image_size')) else image_size
|
|
|
|
|
|
super().__init__()
|
|
super().__init__()
|
|
self.net_output_mid_channels_num = net_output_mid_channels_num
|
|
self.net_output_mid_channels_num = net_output_mid_channels_num
|
|
- self.backbone = backbone(self.classes_num)
|
|
|
|
|
|
+ self.backbone = backbone(self.num_classes)
|
|
self.layers = layers
|
|
self.layers = layers
|
|
self.planes = planes
|
|
self.planes = planes
|
|
|
|
|
|
@@ -476,9 +476,9 @@ class ShelfNetHW(ShelfNetBase):
|
|
super().__init__(*args, **kwargs)
|
|
super().__init__(*args, **kwargs)
|
|
self.ladder = LadderBlockHW(planes=self.net_output_mid_channels_num, layers=self.layers)
|
|
self.ladder = LadderBlockHW(planes=self.net_output_mid_channels_num, layers=self.layers)
|
|
self.decoder = DecoderHW(planes=self.net_output_mid_channels_num, layers=self.layers)
|
|
self.decoder = DecoderHW(planes=self.net_output_mid_channels_num, layers=self.layers)
|
|
- self.se_layer = nn.Linear(self.net_output_mid_channels_num * 2 ** 3, self.classes_num)
|
|
|
|
- self.aux_head = FCNHead(1024, self.classes_num)
|
|
|
|
- self.final = nn.Conv2d(self.net_output_mid_channels_num, self.classes_num, 1)
|
|
|
|
|
|
+ self.se_layer = nn.Linear(self.net_output_mid_channels_num * 2 ** 3, self.num_classes)
|
|
|
|
+ self.aux_head = FCNHead(1024, self.num_classes)
|
|
|
|
+ self.final = nn.Conv2d(self.net_output_mid_channels_num, self.num_classes, 1)
|
|
|
|
|
|
# THE MID CHANNELS NUMBER OF THE NET OUTPUT BLOCK
|
|
# THE MID CHANNELS NUMBER OF THE NET OUTPUT BLOCK
|
|
net_out_planes = self.planes
|
|
net_out_planes = self.planes
|
|
@@ -638,7 +638,7 @@ class ShelfNet18_LW(ShelfNetLW):
|
|
mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
|
|
mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
|
|
|
|
|
|
self.net_output_list.append(
|
|
self.net_output_list.append(
|
|
- NetOutput(out_planes, mid_channels_num, self.classes_num))
|
|
|
|
|
|
+ NetOutput(out_planes, mid_channels_num, self.num_classes))
|
|
|
|
|
|
self.conv_out_list.append(
|
|
self.conv_out_list.append(
|
|
ConvBNReLU(out_planes * 2, out_planes, ks=1, stride=1, padding=0)
|
|
ConvBNReLU(out_planes * 2, out_planes, ks=1, stride=1, padding=0)
|
|
@@ -657,7 +657,7 @@ class ShelfNet34_LW(ShelfNetLW):
|
|
# IF IT'S THE FIRST LAYER THAN THE MID-CHANNELS NUM IS ACTUALLY self.planes
|
|
# IF IT'S THE FIRST LAYER THAN THE MID-CHANNELS NUM IS ACTUALLY self.planes
|
|
mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
|
|
mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
|
|
self.net_output_list.append(
|
|
self.net_output_list.append(
|
|
- NetOutput(net_out_planes, mid_channels_num, self.classes_num))
|
|
|
|
|
|
+ NetOutput(net_out_planes, mid_channels_num, self.num_classes))
|
|
|
|
|
|
net_out_planes *= 2
|
|
net_out_planes *= 2
|
|
|
|
|