Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#426 Rename shelfnet classes_num param

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:hotfix/SG-000-rename_shelfnet_classesnum_param
@@ -402,10 +402,10 @@ class LadderBlockLW(LadderBlockBase):
 
 
 
 
 class NetOutput(ShelfNetModuleBase):
 class NetOutput(ShelfNetModuleBase):
-    def __init__(self, in_chan: int, mid_chan: int, classes_num: int):
+    def __init__(self, in_chan: int, mid_chan: int, num_classes: int):
         super(NetOutput, self).__init__()
         super(NetOutput, self).__init__()
         self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
         self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
-        self.conv_out = nn.Conv2d(mid_chan, classes_num, kernel_size=3, bias=False,
+        self.conv_out = nn.Conv2d(mid_chan, num_classes, kernel_size=3, bias=False,
                                   padding=1)
                                   padding=1)
         self.init_weight()
         self.init_weight()
 
 
@@ -427,15 +427,15 @@ class ShelfNetBase(ShelfNetModuleBase):
     ShelfNetBase - ShelfNet Base Generic Architecture
     ShelfNetBase - ShelfNet Base Generic Architecture
     """
     """
 
 
-    def __init__(self, backbone: ShelfResNetBackBone, planes: int, layers: int, classes_num: int = 21,
+    def __init__(self, backbone: ShelfResNetBackBone, planes: int, layers: int, num_classes: int = 21,
                  image_size: int = 512,
                  image_size: int = 512,
                  net_output_mid_channels_num: int = 64, arch_params: HpmStruct = None):
                  net_output_mid_channels_num: int = 64, arch_params: HpmStruct = None):
-        self.classes_num = arch_params.num_classes if (arch_params and hasattr(arch_params, 'num_classes')) else classes_num
+        self.num_classes = arch_params.num_classes if (arch_params and hasattr(arch_params, 'num_classes')) else num_classes
         self.image_size = arch_params.image_size if (arch_params and hasattr(arch_params, 'image_size')) else image_size
         self.image_size = arch_params.image_size if (arch_params and hasattr(arch_params, 'image_size')) else image_size
 
 
         super().__init__()
         super().__init__()
         self.net_output_mid_channels_num = net_output_mid_channels_num
         self.net_output_mid_channels_num = net_output_mid_channels_num
-        self.backbone = backbone(self.classes_num)
+        self.backbone = backbone(self.num_classes)
         self.layers = layers
         self.layers = layers
         self.planes = planes
         self.planes = planes
 
 
@@ -476,9 +476,9 @@ class ShelfNetHW(ShelfNetBase):
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
         self.ladder = LadderBlockHW(planes=self.net_output_mid_channels_num, layers=self.layers)
         self.ladder = LadderBlockHW(planes=self.net_output_mid_channels_num, layers=self.layers)
         self.decoder = DecoderHW(planes=self.net_output_mid_channels_num, layers=self.layers)
         self.decoder = DecoderHW(planes=self.net_output_mid_channels_num, layers=self.layers)
-        self.se_layer = nn.Linear(self.net_output_mid_channels_num * 2 ** 3, self.classes_num)
-        self.aux_head = FCNHead(1024, self.classes_num)
-        self.final = nn.Conv2d(self.net_output_mid_channels_num, self.classes_num, 1)
+        self.se_layer = nn.Linear(self.net_output_mid_channels_num * 2 ** 3, self.num_classes)
+        self.aux_head = FCNHead(1024, self.num_classes)
+        self.final = nn.Conv2d(self.net_output_mid_channels_num, self.num_classes, 1)
 
 
         # THE MID CHANNELS NUMBER OF THE NET OUTPUT BLOCK
         # THE MID CHANNELS NUMBER OF THE NET OUTPUT BLOCK
         net_out_planes = self.planes
         net_out_planes = self.planes
@@ -638,7 +638,7 @@ class ShelfNet18_LW(ShelfNetLW):
             mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
             mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
 
 
             self.net_output_list.append(
             self.net_output_list.append(
-                NetOutput(out_planes, mid_channels_num, self.classes_num))
+                NetOutput(out_planes, mid_channels_num, self.num_classes))
 
 
             self.conv_out_list.append(
             self.conv_out_list.append(
                 ConvBNReLU(out_planes * 2, out_planes, ks=1, stride=1, padding=0)
                 ConvBNReLU(out_planes * 2, out_planes, ks=1, stride=1, padding=0)
@@ -657,7 +657,7 @@ class ShelfNet34_LW(ShelfNetLW):
             # IF IT'S THE FIRST LAYER THAN THE MID-CHANNELS NUM IS ACTUALLY self.planes
             # IF IT'S THE FIRST LAYER THAN THE MID-CHANNELS NUM IS ACTUALLY self.planes
             mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
             mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
             self.net_output_list.append(
             self.net_output_list.append(
-                NetOutput(net_out_planes, mid_channels_num, self.classes_num))
+                NetOutput(net_out_planes, mid_channels_num, self.num_classes))
 
 
             net_out_planes *= 2
             net_out_planes *= 2
 
 
Discard
@@ -12,19 +12,19 @@ class TestShelfNet(unittest.TestCase):
         """
         """
         dummy_input = torch.randn(1, 3, 512, 512)
         dummy_input = torch.randn(1, 3, 512, 512)
 
 
-        shelfnet18_model = ShelfNet18_LW(classes_num=21)
+        shelfnet18_model = ShelfNet18_LW(num_classes=21)
         # VALIDATES INNER CONV LIST WAS INITIALIZED CORRECTLY
         # VALIDATES INNER CONV LIST WAS INITIALIZED CORRECTLY
         self.assertTrue(shelfnet18_model.conv_out_list)
         self.assertTrue(shelfnet18_model.conv_out_list)
 
 
-        shelfnet34_model = ShelfNet34_LW(classes_num=21)
+        shelfnet34_model = ShelfNet34_LW(num_classes=21)
         # VALIDATES INNER CONV LIST WAS INITIALIZED CORRECTLY
         # VALIDATES INNER CONV LIST WAS INITIALIZED CORRECTLY
         self.assertTrue(shelfnet34_model.conv_out_list)
         self.assertTrue(shelfnet34_model.conv_out_list)
 
 
-        shelfnet50_model = ShelfNet50(classes_num=21)
+        shelfnet50_model = ShelfNet50(num_classes=21)
         # VALIDATES INNER CONV LIST WAS INITIALIZED CORRECTLY
         # VALIDATES INNER CONV LIST WAS INITIALIZED CORRECTLY
         self.assertTrue(shelfnet50_model.conv_out_list)
         self.assertTrue(shelfnet50_model.conv_out_list)
 
 
-        shelfnet101_model = ShelfNet101(classes_num=21)
+        shelfnet101_model = ShelfNet101(num_classes=21)
         # VALIDATES INNER CONV LIST WAS INITIALIZED CORRECTLY
         # VALIDATES INNER CONV LIST WAS INITIALIZED CORRECTLY
         self.assertTrue(shelfnet101_model.conv_out_list)
         self.assertTrue(shelfnet101_model.conv_out_list)
 
 
Discard