#399 Feature/sg 326 replace function with class in architectures

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/sg-326-replace_function_with_class_in_architectures
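
This PR replaces the model factory functions in the architectures modules with thin subclasses, so every entry in the ARCHITECTURES registry is a class and construction happens in __init__ rather than in a function that returns a built model. A minimal sketch of the pattern (illustrative names, not taken from the diff):

    # before: a factory function builds and returns the configured model
    def model_variant(arch_params):
        return BaseModel(depth=12, num_classes=arch_params.num_classes)

    # after: a subclass forwards the same fixed configuration to its parent
    class ModelVariant(BaseModel):
        def __init__(self, arch_params):
            super().__init__(depth=12, num_classes=arch_params.num_classes)

Call sites stay the same (ARCHITECTURES[name](arch_params=...)), but the returned object is now an instance of a named subclass, which isinstance checks such as the SgModule test at the bottom of this diff can see.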
@@ -1,13 +1,13 @@
-from super_gradients.training.models import ResNeXt50, ResNeXt101, googlenet_v1
+from super_gradients.training.models import ResNeXt50, ResNeXt101, GoogleNetV1
 from super_gradients.training.models.classification_models import repvgg, efficientnet, densenet, resnet, regnet
-from super_gradients.training.models.classification_models.mobilenetv2 import mobile_net_v2, mobile_net_v2_135, \
-    custom_mobile_net_v2
+from super_gradients.training.models.classification_models.mobilenetv2 import MobileNetV2Base, MobileNetV2_135, \
+    CustomMobileNetV2
 from super_gradients.training.models.classification_models.mobilenetv3 import mobilenetv3_large, mobilenetv3_small, \
     mobilenetv3_custom
 from super_gradients.training.models.classification_models.shufflenetv2 import ShufflenetV2_x0_5, ShufflenetV2_x1_0, \
     ShufflenetV2_x1_5, \
     ShufflenetV2_x2_0, CustomizedShuffleNetV2
-from super_gradients.training.models.classification_models.vit import vit_base, vit_large, vit_huge
+from super_gradients.training.models.classification_models.vit import ViTBase, ViTLarge, ViTHuge
 from super_gradients.training.models.detection_models.csp_darknet53 import CSPDarknet53
 from super_gradients.training.models.detection_models.darknet53 import Darknet53
 from super_gradients.training.models.detection_models.ssd import SSDMobileNetV1, SSDLiteMobileNetV2
@@ -20,7 +20,7 @@ from super_gradients.training.models.segmentation_models.stdc import STDC1Classi
     STDC1Seg, STDC2Seg
 
 from super_gradients.training.models.kd_modules.kd_module import KDModule
-from super_gradients.training.models.classification_models.beit import beit_base_patch16_224, beit_large_patch16_224
+from super_gradients.training.models.classification_models.beit import BeitBasePatch16_224, BeitLargePatch16_224
 from super_gradients.training.models.segmentation_models.ppliteseg import PPLiteSegT, PPLiteSegB
 
 
@@ -135,17 +135,17 @@ ARCHITECTURES = {ModelNames.RESNET18: resnet.ResNet18,
                  ModelNames.CUSTOM_RESNET50: resnet.CustomizedResnet50,
                  ModelNames.CUSTOM_RESNET_CIFAR: resnet.CustomizedResnetCifar,
                  ModelNames.CUSTOM_RESNET50_CIFAR: resnet.CustomizedResnet50Cifar,
-                 ModelNames.MOBILENET_V2: mobile_net_v2,
-                 ModelNames.MOBILE_NET_V2_135: mobile_net_v2_135,
-                 ModelNames.CUSTOM_MOBILENET_V2: custom_mobile_net_v2,
+                 ModelNames.MOBILENET_V2: MobileNetV2Base,
+                 ModelNames.MOBILE_NET_V2_135: MobileNetV2_135,
+                 ModelNames.CUSTOM_MOBILENET_V2: CustomMobileNetV2,
                  ModelNames.MOBILENET_V3_LARGE: mobilenetv3_large,
                  ModelNames.MOBILENET_V3_SMALL: mobilenetv3_small,
                  ModelNames.MOBILENET_V3_CUSTOM: mobilenetv3_custom,
                  ModelNames.CUSTOM_DENSENET: densenet.CustomizedDensnet,
-                 ModelNames.DENSENET121: densenet.densenet121,
-                 ModelNames.DENSENET161: densenet.densenet161,
-                 ModelNames.DENSENET169: densenet.densenet169,
-                 ModelNames.DENSENET201: densenet.densenet201,
+                 ModelNames.DENSENET121: densenet.DenseNet121,
+                 ModelNames.DENSENET161: densenet.DenseNet161,
+                 ModelNames.DENSENET169: densenet.DenseNet169,
+                 ModelNames.DENSENET201: densenet.DenseNet201,
                  ModelNames.SHELFNET18_LW: ShelfNet18_LW,
                  ModelNames.SHELFNET34_LW: ShelfNet34_LW,
                  ModelNames.SHELFNET50_3343: ShelfNet503343,
@@ -160,17 +160,17 @@ ARCHITECTURES = {ModelNames.RESNET18: resnet.ResNet18,
                  ModelNames.CSP_DARKNET53: CSPDarknet53,
                  ModelNames.RESNEXT50: ResNeXt50,
                  ModelNames.RESNEXT101: ResNeXt101,
-                 ModelNames.GOOGLENET_V1: googlenet_v1,
-                 ModelNames.EFFICIENTNET_B0: efficientnet.b0,
-                 ModelNames.EFFICIENTNET_B1: efficientnet.b1,
-                 ModelNames.EFFICIENTNET_B2: efficientnet.b2,
-                 ModelNames.EFFICIENTNET_B3: efficientnet.b3,
-                 ModelNames.EFFICIENTNET_B4: efficientnet.b4,
-                 ModelNames.EFFICIENTNET_B5: efficientnet.b5,
-                 ModelNames.EFFICIENTNET_B6: efficientnet.b6,
-                 ModelNames.EFFICIENTNET_B7: efficientnet.b7,
-                 ModelNames.EFFICIENTNET_B8: efficientnet.b8,
-                 ModelNames.EFFICIENTNET_L2: efficientnet.l2,
+                 ModelNames.GOOGLENET_V1: GoogleNetV1,
+                 ModelNames.EFFICIENTNET_B0: efficientnet.EfficientNetB0,
+                 ModelNames.EFFICIENTNET_B1: efficientnet.EfficientNetB1,
+                 ModelNames.EFFICIENTNET_B2: efficientnet.EfficientNetB2,
+                 ModelNames.EFFICIENTNET_B3: efficientnet.EfficientNetB3,
+                 ModelNames.EFFICIENTNET_B4: efficientnet.EfficientNetB4,
+                 ModelNames.EFFICIENTNET_B5: efficientnet.EfficientNetB5,
+                 ModelNames.EFFICIENTNET_B6: efficientnet.EfficientNetB6,
+                 ModelNames.EFFICIENTNET_B7: efficientnet.EfficientNetB7,
+                 ModelNames.EFFICIENTNET_B8: efficientnet.EfficientNetB8,
+                 ModelNames.EFFICIENTNET_L2: efficientnet.EfficientNetL2,
                  ModelNames.CUSTOMIZEDEFFICIENTNET: efficientnet.CustomizedEfficientnet,
                  ModelNames.REGNETY200: regnet.RegNetY200,
                  ModelNames.REGNETY400: regnet.RegNetY400,
@@ -209,11 +209,11 @@ ARCHITECTURES = {ModelNames.RESNET18: resnet.ResNet18,
                  ModelNames.STDC2_SEG75: STDC2Seg,
                  ModelNames.REGSEG48: RegSeg48,
                  ModelNames.KD_MODULE: KDModule,
-                 ModelNames.VIT_BASE: vit_base,
-                 ModelNames.VIT_LARGE: vit_large,
-                 ModelNames.VIT_HUGE: vit_huge,
-                 ModelNames.BEIT_BASE_PATCH16_224: beit_base_patch16_224,
-                 ModelNames.BEIT_LARGE_PATCH16_224: beit_large_patch16_224,
+                 ModelNames.VIT_BASE: ViTBase,
+                 ModelNames.VIT_LARGE: ViTLarge,
+                 ModelNames.VIT_HUGE: ViTHuge,
+                 ModelNames.BEIT_BASE_PATCH16_224: BeitBasePatch16_224,
+                 ModelNames.BEIT_LARGE_PATCH16_224: BeitLargePatch16_224,
                  ModelNames.PP_LITE_T_SEG: PPLiteSegT,
                  ModelNames.PP_LITE_T_SEG50: PPLiteSegT,
                  ModelNames.PP_LITE_T_SEG75: PPLiteSegT,
@@ -417,31 +417,31 @@ class Beit(SgModule):
             self.head = nn.Linear(self.head.in_features, new_num_classes)
 
 
-def beit_base_patch16_224(arch_params: HpmStruct):
-    model_kwargs = HpmStruct(patch_size=(16, 16),
-                             embed_dim=768,
-                             depth=12,
-                             num_heads=12,
-                             mlp_ratio=4,
-                             use_abs_pos_emb=False,
-                             use_rel_pos_bias=True,
-                             init_values=0.1)
-    model_kwargs.override(**arch_params.to_dict())
-    model = Beit(**model_kwargs.to_dict())
-    return model
-
-
-def beit_large_patch16_224(arch_params: HpmStruct):
-    model_kwargs = HpmStruct(patch_size=(16, 16),
-                             embed_dim=1024,
-                             depth=24,
-                             num_heads=16,
-                             mlp_ratio=4,
-                             qkv_bias=True,
-                             use_abs_pos_emb=False,
-                             use_rel_pos_bias=True,
-                             init_values=1e-5)
-
-    model_kwargs.override(**arch_params.to_dict())
-    model = Beit(**model_kwargs.to_dict())
-    return model
+class BeitBasePatch16_224(Beit):
+    def __init__(self, arch_params: HpmStruct):
+        model_kwargs = HpmStruct(patch_size=(16, 16),
+                                 embed_dim=768,
+                                 depth=12,
+                                 num_heads=12,
+                                 mlp_ratio=4,
+                                 use_abs_pos_emb=False,
+                                 use_rel_pos_bias=True,
+                                 init_values=0.1)
+        model_kwargs.override(**arch_params.to_dict())
+        super(BeitBasePatch16_224, self).__init__(**model_kwargs.to_dict())
+
+
+class BeitLargePatch16_224(Beit):
+    def __init__(self, arch_params: HpmStruct):
+        model_kwargs = HpmStruct(patch_size=(16, 16),
+                                 embed_dim=1024,
+                                 depth=24,
+                                 num_heads=16,
+                                 mlp_ratio=4,
+                                 qkv_bias=True,
+                                 use_abs_pos_emb=False,
+                                 use_rel_pos_bias=True,
+                                 init_values=1e-5)
+
+        model_kwargs.override(**arch_params.to_dict())
+        super(BeitLargePatch16_224, self).__init__(**model_kwargs.to_dict())
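
The HpmStruct override mechanism is unchanged by the class form: fixed BEiT defaults are declared in __init__ and entries from arch_params are merged over them before the parent constructor runs. A minimal usage sketch, assuming HpmStruct is importable from super_gradients.training.utils:

    from super_gradients.training.utils import HpmStruct

    # num_classes is merged over the fixed BEiT-Base defaults above
    model = BeitBasePatch16_224(HpmStruct(num_classes=100))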
@@ -132,27 +132,31 @@ class DenseNet(SgModule):
         return out
 
 
-def CustomizedDensnet(arch_params):
-    return DenseNet(growth_rate=arch_params.growth_rate if hasattr(arch_params, "growth_rate") else 32,
-                    structure=arch_params.structure if hasattr(arch_params, "structure") else [6, 12, 24, 16],
-                    num_init_features=arch_params.num_init_features if hasattr(arch_params,
-                                                                               "num_init_features") else 64,
-                    bn_size=arch_params.bn_size if hasattr(arch_params, "bn_size") else 4,
-                    drop_rate=arch_params.drop_rate if hasattr(arch_params, "drop_rate") else 0,
-                    num_classes=arch_params.num_classes)
+class CustomizedDensnet(DenseNet):
+    def __init__(self, arch_params):
+        super().__init__(growth_rate=arch_params.growth_rate if hasattr(arch_params, "growth_rate") else 32,
+                         structure=arch_params.structure if hasattr(arch_params, "structure") else [6, 12, 24, 16],
+                         num_init_features=arch_params.num_init_features if hasattr(arch_params, "num_init_features") else 64,
+                         bn_size=arch_params.bn_size if hasattr(arch_params, "bn_size") else 4,
+                         drop_rate=arch_params.drop_rate if hasattr(arch_params, "drop_rate") else 0,
+                         num_classes=arch_params.num_classes)
 
 
-def densenet121(arch_params):
-    return DenseNet(32, [6, 12, 24, 16], 64, 4, 0, arch_params.num_classes)
+class DenseNet121(DenseNet):
+    def __init__(self, arch_params):
+        super().__init__(32, [6, 12, 24, 16], 64, 4, 0, arch_params.num_classes)
 
 
-def densenet161(arch_params):
-    return DenseNet(48, [6, 12, 36, 24], 96, 4, 0, arch_params.num_classes)
+class DenseNet161(DenseNet):
+    def __init__(self, arch_params):
+        super().__init__(48, [6, 12, 36, 24], 96, 4, 0, arch_params.num_classes)
 
 
-def densenet169(arch_params):
-    return DenseNet(32, [6, 12, 32, 32], 64, 4, 0, arch_params.num_classes)
+class DenseNet169(DenseNet):
+    def __init__(self, arch_params):
+        super().__init__(32, [6, 12, 32, 32], 64, 4, 0, arch_params.num_classes)
 
 
-def densenet201(arch_params):
-    return DenseNet(32, [6, 12, 48, 32], 64, 4, 0, arch_params.num_classes)
+class DenseNet201(DenseNet):
+    def __init__(self, arch_params):
+        super().__init__(32, [6, 12, 48, 32], 64, 4, 0, arch_params.num_classes)
@@ -5,8 +5,8 @@ Pre-trained checkpoints converted to Deci's code base with the reported accuracy
 """
 """
 #######################################################################################################################
 #######################################################################################################################
 #   1. Since each net expects a specific image size, make sure to build the dataset with the correct image size:
 #   1. Since each net expects a specific image size, make sure to build the dataset with the correct image size:
-#         b0 - (224, 256), b1 - (240, 274), b2 - (260, 298), b3 - (300, 342), b4 - (380, 434),
-#         b5 - (456, 520), b6 - (528, 602), b7 - (600, 684), b8 - (672, 768), l2 - (800, 914)
+#         EfficientNetB0 - (224, 256), EfficientNetB1 - (240, 274), EfficientNetB2 - (260, 298), EfficientNetB3 - (300, 342), EfficientNetB4 - (380, 434),
+#         EfficientNetB5 - (456, 520), EfficientNetB6 - (528, 602), EfficientNetB7 - (600, 684), EfficientNetB8 - (672, 768), EfficientNetL2 - (800, 914)
 #         You should build the DataSetInterface with the following dictionary:
 #           ImageNetDatasetInterface(dataset_params = {'crop': 260, 'resize':  298})
 #   2. Pre-trained ImageNet models can be found in S3://deci-model-repository-research/efficientnet_b#/ckpt_best.pth
@@ -551,19 +551,10 @@ class EfficientNet(SgModule):
         super().load_state_dict(pretrained_model_weights_dict, strict)
 
 
-def build_efficientnet(width, depth, res, dropout, arch_params):
-    """
-
-    :param width:
-    :param depth:
-    :param res:
-    :param dropout:
-    :param arch_params:
-    :return:
-    """
+def get_efficientnet_params(width: float, depth: float, res: float, dropout: float, arch_params: HpmStruct):
     print(f"\nNOTICE: \nachieving EfficientNet\'s reported accuracy requires specific image resolution."
     print(f"\nNOTICE: \nachieving EfficientNet\'s reported accuracy requires specific image resolution."
           f"\nPlease verify image size is {res}x{res} for this specific EfficientNet configuration\n")
           f"\nPlease verify image size is {res}x{res} for this specific EfficientNet configuration\n")
-    # Blocks args for the whole model(efficientnet-b0 by default)
+    # Blocks args for the whole model(efficientnet-EfficientNetB0 by default)
     # It will be modified in the construction of EfficientNet Class according to model
     blocks_args = BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
                                        'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
@@ -576,49 +567,74 @@ def build_efficientnet(width, depth, res, dropout, arch_params):
                                    "depth_divisor": 8, "min_depth": None, "backbone_mode": False})
                                    "depth_divisor": 8, "min_depth": None, "backbone_mode": False})
     # Update arch_params
     # Update arch_params
     arch_params_new.override(**arch_params.to_dict())
     arch_params_new.override(**arch_params.to_dict())
-    return EfficientNet(blocks_args, arch_params_new)
+    return blocks_args, arch_params_new
 
 
-def b0(arch_params):
-    return build_efficientnet(1.0, 1.0, 224, 0.2, arch_params)
+class EfficientNetB0(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=1.0, depth=1.0, res=224, dropout=0.2, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def b1(arch_params):
-    return build_efficientnet(1.0, 1.1, 240, 0.2, arch_params)
+class EfficientNetB1(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=1.0, depth=1.1, res=240, dropout=0.2, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def b2(arch_params):
-    return build_efficientnet(1.1, 1.2, 260, 0.3, arch_params)
+class EfficientNetB2(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=1.1, depth=1.2, res=260, dropout=0.3, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def b3(arch_params):
-    return build_efficientnet(1.2, 1.4, 300, 0.3, arch_params)
+class EfficientNetB3(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=1.2, depth=1.4, res=300, dropout=0.3, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def b4(arch_params):
-    return build_efficientnet(1.4, 1.8, 380, 0.4, arch_params)
+class EfficientNetB4(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=1.4, depth=1.8, res=380, dropout=0.4, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def b5(arch_params):
-    return build_efficientnet(1.6, 2.2, 456, 0.4, arch_params)
+class EfficientNetB5(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=1.6, depth=2.2, res=456, dropout=0.4, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def b6(arch_params):
-    return build_efficientnet(1.8, 2.6, 528, 0.5, arch_params)
+class EfficientNetB6(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=1.8, depth=2.6, res=528, dropout=0.5, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def b7(arch_params):
-    return build_efficientnet(2.0, 3.1, 600, 0.5, arch_params)
+class EfficientNetB7(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=2.0, depth=3.1, res=600, dropout=0.5, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def b8(arch_params):
-    return build_efficientnet(2.2, 3.6, 672, 0.5, arch_params)
+class EfficientNetB8(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=2.2, depth=3.6, res=672, dropout=0.5, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def l2(arch_params):
-    return build_efficientnet(4.3, 5.3, 800, 0.5, arch_params)
+class EfficientNetL2(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=4.3, depth=5.3, res=800, dropout=0.5, arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
 
 
-def CustomizedEfficientnet(arch_params):
-    return build_efficientnet(arch_params.width_coefficient, arch_params.depth_coefficient, arch_params.res,
-                              arch_params.dropout_rate, arch_params)
+class CustomizedEfficientnet(EfficientNet):
+    def __init__(self, arch_params):
+        blocks_args, arch_params = get_efficientnet_params(width=arch_params.width_coefficient,
+                                                           depth=arch_params.depth_coefficient,
+                                                           res=arch_params.res,
+                                                           dropout=arch_params.dropout_rate,
+                                                           arch_params=arch_params)
+        super().__init__(blocks_args=blocks_args, arch_params=arch_params)
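
Per the notice at the top of this file, each EfficientNet variant only reaches its reported accuracy at its own input resolution, so the dataset must be built to match the chosen class. A sketch for EfficientNetB2 (crop 260 / resize 298 per the table above; the import paths are assumptions, not taken from this diff):

    from super_gradients.training.utils import HpmStruct
    # assumed location of the dataset interface shown in the comment block above
    from super_gradients.training.datasets.dataset_interfaces import ImageNetDatasetInterface

    dataset = ImageNetDatasetInterface(dataset_params={'crop': 260, 'resize': 298})
    model = EfficientNetB2(HpmStruct(num_classes=1000))  # expects 260x260 inputs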
@@ -246,5 +246,6 @@ class BasicConv2d(nn.Module):
         return x
 
 
-def googlenet_v1(arch_params):
-    return GoogLeNet(aux_logits=False, num_classes=arch_params.num_classes, dropout=arch_params.dropout)
+class GoogleNetV1(GoogLeNet):
+    def __init__(self, arch_params):
+        super(GoogleNetV1, self).__init__(aux_logits=False, num_classes=arch_params.num_classes, dropout=arch_params.dropout)
@@ -188,52 +188,50 @@ class MobileNetV2(MobileNetBase):
                 m.bias.data.zero_()
 
 
-def mobile_net_v2(arch_params):
-    """
-    :param arch_params: HpmStruct
-        must contain: 'num_classes': int
-    :return: MobileNetV2: nn.Module
-    """
-    return MobileNetV2(
-        num_classes=arch_params.num_classes,
-        width_mult=1.0,
-        structure=None,
-        dropout=get_param(arch_params, "dropout", 0.0),
-        in_channels=get_param(arch_params, "in_channels", 3),
-    )
-
-
-def mobile_net_v2_135(arch_params):
-    """
-    This Model achieves 75.73% on Imagenet - similar to Resnet50
-    :param arch_params: HpmStruct
-        must contain: 'num_classes': int
-    :return: MobileNetV2: nn.Module
-    """
-
-    return MobileNetV2(
-        num_classes=arch_params.num_classes,
-        width_mult=1.35,
-        structure=None,
-        dropout=get_param(arch_params, "dropout", 0.0),
-        in_channels=get_param(arch_params, "in_channels", 3),
-    )
-
-
-def custom_mobile_net_v2(arch_params):
-    """
-    :param arch_params: HpmStruct
-        must contain:
-            'num_classes': int
-            'width_mult': float
-            'structure' : list. specify the mobilenetv2 architecture
-    :return: MobileNetV2: nn.Module
-    """
-
-    return MobileNetV2(
-        num_classes=arch_params.num_classes,
-        width_mult=arch_params.width_mult,
-        structure=arch_params.structure,
-        dropout=get_param(arch_params, "dropout", 0.0),
-        in_channels=get_param(arch_params, "in_channels", 3),
-    )
+class MobileNetV2Base(MobileNetV2):
+    def __init__(self, arch_params):
+        """
+        :param arch_params: HpmStruct
+            must contain: 'num_classes': int
+        """
+        super().__init__(
+            num_classes=arch_params.num_classes,
+            width_mult=1.0,
+            structure=None,
+            dropout=get_param(arch_params, "dropout", 0.0),
+            in_channels=get_param(arch_params, "in_channels", 3),
+        )
+
+
+class MobileNetV2_135(MobileNetV2):
+    def __init__(self, arch_params):
+        """
+        This Model achieves 75.73% on Imagenet - similar to Resnet50
+        :param arch_params: HpmStruct
+            must contain: 'num_classes': int
+        """
+        super().__init__(
+            num_classes=arch_params.num_classes,
+            width_mult=1.35,
+            structure=None,
+            dropout=get_param(arch_params, "dropout", 0.0),
+            in_channels=get_param(arch_params, "in_channels", 3),
+        )
+
+
+class CustomMobileNetV2(MobileNetV2):
+    def __init__(self, arch_params):
+        """
+        :param arch_params: HpmStruct
+            must contain:
+                'num_classes': int
+                'width_mult': float
+                'structure' : list. specify the mobilenetv2 architecture
+        """
+        super().__init__(
+            num_classes=arch_params.num_classes,
+            width_mult=arch_params.width_mult,
+            structure=arch_params.structure,
+            dropout=get_param(arch_params, "dropout", 0.0),
+            in_channels=get_param(arch_params, "in_channels", 3),
+        )
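
Optional fields here are read with get_param, so MobileNetV2Base strictly needs only num_classes, while CustomMobileNetV2 also requires width_mult and structure. A short sketch (HpmStruct path assumed; the structure rows are illustrative, not values from this diff):

    from super_gradients.training.utils import HpmStruct

    # dropout falls back to 0.0 and in_channels to 3 via get_param
    base = MobileNetV2Base(HpmStruct(num_classes=1000))

    # the custom variant reads width_mult and structure directly from arch_params
    custom = CustomMobileNetV2(HpmStruct(num_classes=10, width_mult=0.75,
                                         structure=[[1, 16, 1, 1], [6, 24, 2, 2]]))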
@@ -182,61 +182,63 @@ class MobileNetV3(MobileNetBase):
                 m.bias.data.zero_()
 
 
-def mobilenetv3_large(arch_params):
+class mobilenetv3_large(MobileNetV3):
     """
     """
     Constructs a MobileNetV3-Large model
     Constructs a MobileNetV3-Large model
     """
     """
-    width_mult = arch_params.width_mult if hasattr(arch_params, 'width_mult') else 1.
-    cfgs = [
-        # k, t, c, SE, HS, s
-        [3, 1, 16, 0, 0, 1],
-        [3, 4, 24, 0, 0, 2],
-        [3, 3, 24, 0, 0, 1],
-        [5, 3, 40, 1, 0, 2],
-        [5, 3, 40, 1, 0, 1],
-        [5, 3, 40, 1, 0, 1],
-        [3, 6, 80, 0, 1, 2],
-        [3, 2.5, 80, 0, 1, 1],
-        [3, 2.3, 80, 0, 1, 1],
-        [3, 2.3, 80, 0, 1, 1],
-        [3, 6, 112, 1, 1, 1],
-        [3, 6, 112, 1, 1, 1],
-        [5, 6, 160, 1, 1, 2],
-        [5, 6, 160, 1, 1, 1],
-        [5, 6, 160, 1, 1, 1]
-    ]
-    return MobileNetV3(cfgs, mode='large', num_classes=arch_params.num_classes, width_mult=width_mult,
-                       in_channels=get_param(arch_params, "in_channels", 3))
-
-
-def mobilenetv3_small(arch_params):
+    def __init__(self, arch_params):
+        width_mult = arch_params.width_mult if hasattr(arch_params, 'width_mult') else 1.
+        cfgs = [
+            # k, t, c, SE, HS, s
+            [3, 1, 16, 0, 0, 1],
+            [3, 4, 24, 0, 0, 2],
+            [3, 3, 24, 0, 0, 1],
+            [5, 3, 40, 1, 0, 2],
+            [5, 3, 40, 1, 0, 1],
+            [5, 3, 40, 1, 0, 1],
+            [3, 6, 80, 0, 1, 2],
+            [3, 2.5, 80, 0, 1, 1],
+            [3, 2.3, 80, 0, 1, 1],
+            [3, 2.3, 80, 0, 1, 1],
+            [3, 6, 112, 1, 1, 1],
+            [3, 6, 112, 1, 1, 1],
+            [5, 6, 160, 1, 1, 2],
+            [5, 6, 160, 1, 1, 1],
+            [5, 6, 160, 1, 1, 1]
+        ]
+        super().__init__(cfgs, mode='large', num_classes=arch_params.num_classes, width_mult=width_mult,
+                         in_channels=get_param(arch_params, "in_channels", 3))
+
+
+class mobilenetv3_small(MobileNetV3):
     """
     """
     Constructs a MobileNetV3-Small model
     Constructs a MobileNetV3-Small model
     """
     """
-    width_mult = arch_params.width_mult if hasattr(arch_params, 'width_mult') else 1.
-    cfgs = [
-        # k, t, c, SE, HS, s
-        [3, 1, 16, 1, 0, 2],
-        [3, 4.5, 24, 0, 0, 2],
-        [3, 3.67, 24, 0, 0, 1],
-        [5, 4, 40, 1, 1, 2],
-        [5, 6, 40, 1, 1, 1],
-        [5, 6, 40, 1, 1, 1],
-        [5, 3, 48, 1, 1, 1],
-        [5, 3, 48, 1, 1, 1],
-        [5, 6, 96, 1, 1, 2],
-        [5, 6, 96, 1, 1, 1],
-        [5, 6, 96, 1, 1, 1],
-    ]
-
-    return MobileNetV3(cfgs, mode='small', num_classes=arch_params.num_classes, width_mult=width_mult,
-                       in_channels=get_param(arch_params, "in_channels", 3))
-
-
-def mobilenetv3_custom(arch_params):
+    def __init__(self, arch_params):
+        width_mult = arch_params.width_mult if hasattr(arch_params, 'width_mult') else 1.
+        cfgs = [
+            # k, t, c, SE, HS, s
+            [3, 1, 16, 1, 0, 2],
+            [3, 4.5, 24, 0, 0, 2],
+            [3, 3.67, 24, 0, 0, 1],
+            [5, 4, 40, 1, 1, 2],
+            [5, 6, 40, 1, 1, 1],
+            [5, 6, 40, 1, 1, 1],
+            [5, 3, 48, 1, 1, 1],
+            [5, 3, 48, 1, 1, 1],
+            [5, 6, 96, 1, 1, 2],
+            [5, 6, 96, 1, 1, 1],
+            [5, 6, 96, 1, 1, 1],
+        ]
+        super().__init__(cfgs, mode='small', num_classes=arch_params.num_classes, width_mult=width_mult,
+                         in_channels=get_param(arch_params, "in_channels", 3))
+
+
+class mobilenetv3_custom(MobileNetV3):
     """
     """
     Constructs a MobileNetV3-Customized model
     Constructs a MobileNetV3-Customized model
     """
     """
-    return MobileNetV3(cfgs=arch_params.structure, mode=arch_params.mode, num_classes=arch_params.num_classes,
-                       width_mult=arch_params.width_mult,
-                       in_channels=get_param(arch_params, "in_channels", 3))
+    def __init__(self, arch_params):
+        super().__init__(cfgs=arch_params.structure, mode=arch_params.mode, num_classes=arch_params.num_classes,
+                         width_mult=arch_params.width_mult,
+                         in_channels=get_param(arch_params, "in_channels", 3))
@@ -238,65 +238,76 @@ class ResNet(SgModule):
             self.linear = nn.Linear(width_multiplier(512, self.width_mult) * self.expansion, new_num_classes)
 
 
-def ResNet18(arch_params, num_classes=None):
-    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes,
-                  droppath_prob=get_param(arch_params, "droppath_prob", 0),
-                  backbone_mode=get_param(arch_params, "backbone_mode", False))
+class ResNet18(ResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes,
+                         droppath_prob=get_param(arch_params, "droppath_prob", 0),
+                         backbone_mode=get_param(arch_params, "backbone_mode", False))
 
 
-def ResNet18Cifar(arch_params, num_classes=None):
-    return CifarResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes)
+class ResNet18Cifar(CifarResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes)
 
 
-def ResNet34(arch_params, num_classes=None):
-    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes or arch_params.num_classes,
-                  droppath_prob=get_param(arch_params, "droppath_prob", 0),
-                  backbone_mode=get_param(arch_params, "backbone_mode", False))
+class ResNet34(ResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(BasicBlock, [3, 4, 6, 3], num_classes=num_classes or arch_params.num_classes,
+                         droppath_prob=get_param(arch_params, "droppath_prob", 0),
+                         backbone_mode=get_param(arch_params, "backbone_mode", False))
 
 
-def ResNet50(arch_params, num_classes=None):
-    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes or arch_params.num_classes,
-                  droppath_prob=get_param(arch_params, "droppath_prob", 0),
-                  backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
+class ResNet50(ResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(Bottleneck, [3, 4, 6, 3], num_classes=num_classes or arch_params.num_classes,
+                         droppath_prob=get_param(arch_params, "droppath_prob", 0),
+                         backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
 
 
-def ResNet50_3343(arch_params, num_classes=None):
-    return ResNet(Bottleneck, [3, 3, 4, 3], num_classes=num_classes or arch_params.num_classes,
-                  droppath_prob=get_param(arch_params, "droppath_prob", 0),
-                  backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
+class ResNet50_3343(ResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(Bottleneck, [3, 3, 4, 3], num_classes=num_classes or arch_params.num_classes,
+                         droppath_prob=get_param(arch_params, "droppath_prob", 0),
+                         backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
 
 
-def ResNet101(arch_params, num_classes=None):
-    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes or arch_params.num_classes,
-                  droppath_prob=get_param(arch_params, "droppath_prob", 0),
-                  backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
+class ResNet101(ResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(Bottleneck, [3, 4, 23, 3], num_classes=num_classes or arch_params.num_classes,
+                         droppath_prob=get_param(arch_params, "droppath_prob", 0),
+                         backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
 
 
-def ResNet152(arch_params, num_classes=None):
-    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes or arch_params.num_classes,
-                  droppath_prob=get_param(arch_params, "droppath_prob", 0),
-                  backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
+class ResNet152(ResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(Bottleneck, [3, 8, 36, 3], num_classes=num_classes or arch_params.num_classes,
+                         droppath_prob=get_param(arch_params, "droppath_prob", 0),
+                         backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
 
 
-def CustomizedResnetCifar(arch_params, num_classes=None):
-    return CifarResNet(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult,
-                       num_classes=num_classes or arch_params.num_classes)
+class CustomizedResnetCifar(CifarResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult,
+                         num_classes=num_classes or arch_params.num_classes)
 
 
-def CustomizedResnet50Cifar(arch_params, num_classes=None):
-    return CifarResNet(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult,
-                       num_classes=num_classes or arch_params.num_classes, expansion=4)
+class CustomizedResnet50Cifar(CifarResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult,
+                         num_classes=num_classes or arch_params.num_classes, expansion=4)
 
 
-def CustomizedResnet(arch_params, num_classes=None):
-    return ResNet(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult,
-                  num_classes=num_classes or arch_params.num_classes,
-                  droppath_prob=get_param(arch_params, "droppath_prob", 0),
-                  backbone_mode=get_param(arch_params, "backbone_mode", False))
+class CustomizedResnet(ResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult,
+                         num_classes=num_classes or arch_params.num_classes,
+                         droppath_prob=get_param(arch_params, "droppath_prob", 0),
+                         backbone_mode=get_param(arch_params, "backbone_mode", False))
 
 
-def CustomizedResnet50(arch_params, num_classes=None):
-    return ResNet(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult,
-                  num_classes=num_classes or arch_params.num_classes,
-                  droppath_prob=get_param(arch_params, "droppath_prob", 0),
-                  backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
+class CustomizedResnet50(ResNet):
+    def __init__(self, arch_params, num_classes=None):
+        super().__init__(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult,
+                         num_classes=num_classes or arch_params.num_classes,
+                         droppath_prob=get_param(arch_params, "droppath_prob", 0),
+                         backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
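
The ResNet wrappers keep the optional num_classes argument, and num_classes or arch_params.num_classes means an explicitly passed value takes precedence over the one in arch_params. Sketch (HpmStruct path assumed):

    from super_gradients.training.utils import HpmStruct

    m1 = ResNet18(HpmStruct(num_classes=10))                    # from arch_params
    m2 = ResNet18(HpmStruct(num_classes=1000), num_classes=10)  # explicit value wins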
@@ -139,20 +139,23 @@ class ResNeXt(SgModule):
         return out
 
 
-def CustomizedResNeXt(arch_params):
-    return ResNeXt(layers=arch_params.structure if hasattr(arch_params, "structure") else [3, 3, 3],
-                   bottleneck_width=arch_params.num_init_features if hasattr(arch_params, "bottleneck_width") else 64,
-                   cardinality=arch_params.bn_size if hasattr(arch_params, "cardinality") else 32,
-                   num_classes=arch_params.num_classes,
-                   replace_stride_with_dilation=arch_params.replace_stride_with_dilation if
-                   hasattr(arch_params, "replace_stride_with_dilation") else None)
-
-
-def ResNeXt50(arch_params):
-    return ResNeXt(layers=[3, 4, 6, 3], cardinality=32, bottleneck_width=4,
-                   num_classes=arch_params.num_classes)
-
-
-def ResNeXt101(arch_params):
-    return ResNeXt(layers=[3, 4, 23, 3], cardinality=32, bottleneck_width=8,
-                   num_classes=arch_params.num_classes)
+class CustomizedResNeXt(ResNeXt):
+    def __init__(self, arch_params):
+        super(CustomizedResNeXt, self).__init__(layers=arch_params.structure if hasattr(arch_params, "structure") else [3, 3, 3],
+                                                bottleneck_width=arch_params.num_init_features if hasattr(arch_params, "bottleneck_width") else 64,
+                                                cardinality=arch_params.bn_size if hasattr(arch_params, "cardinality") else 32,
+                                                num_classes=arch_params.num_classes,
+                                                replace_stride_with_dilation=arch_params.replace_stride_with_dilation if
+                                                hasattr(arch_params, "replace_stride_with_dilation") else None)
+
+
+class ResNeXt50(ResNeXt):
+    def __init__(self, arch_params):
+        super(ResNeXt50, self).__init__(layers=[3, 4, 6, 3], cardinality=32, bottleneck_width=4,
+                                        num_classes=arch_params.num_classes)
+
+
+class ResNeXt101(ResNeXt):
+    def __init__(self, arch_params):
+        super(ResNeXt101, self).__init__(layers=[3, 4, 23, 3], cardinality=32, bottleneck_width=8,
+                                         num_classes=arch_params.num_classes)
@@ -180,34 +180,37 @@ class ViT(SgModule):
             self.head = nn.Linear(self.head.in_features, new_num_classes)
 
 
-def vit_base(arch_params, num_classes=None, backbone_mode=None):
-    return ViT(image_size=get_param(arch_params, "image_size", (224, 224)),
-               patch_size=get_param(arch_params, "patch_size", (16, 16)),
-               num_classes=num_classes or arch_params.num_classes,
-               hidden_dim=768, depth=12, heads=12, mlp_dim=3072,
-               in_channels=get_param(arch_params, 'in_channels', 3),
-               dropout_prob=get_param(arch_params, "dropout_prob", 0),
-               emb_dropout_prob=get_param(arch_params, "emb_dropout_prob", 0),
-               backbone_mode=backbone_mode)
-
-
-def vit_large(arch_params, num_classes=None, backbone_mode=None):
-    return ViT(image_size=get_param(arch_params, "image_size", (224, 224)),
-               patch_size=get_param(arch_params, "patch_size", (16, 16)),
-               num_classes=num_classes or arch_params.num_classes,
-               hidden_dim=1024, depth=24, heads=16, mlp_dim=4096,
-               in_channels=get_param(arch_params, 'in_channels', 3),
-               dropout_prob=get_param(arch_params, "dropout_prob", 0),
-               emb_dropout_prob=get_param(arch_params, "emb_dropout_prob", 0),
-               backbone_mode=backbone_mode)
-
-
-def vit_huge(arch_params, num_classes=None, backbone_mode=None):
-    return ViT(image_size=get_param(arch_params, "image_size", (224, 224)),
-               patch_size=get_param(arch_params, "patch_size", (16, 16)),
-               num_classes=num_classes or arch_params.num_classes,
-               hidden_dim=1280, depth=32, heads=16, mlp_dim=5120,
-               in_channels=get_param(arch_params, 'in_channels', 3),
-               dropout_prob=get_param(arch_params, "dropout_prob", 0),
-               emb_dropout_prob=get_param(arch_params, "emb_dropout_prob", 0),
-               backbone_mode=backbone_mode)
+class ViTBase(ViT):
+    def __init__(self, arch_params, num_classes=None, backbone_mode=None):
+        super(ViTBase, self).__init__(image_size=get_param(arch_params, "image_size", (224, 224)),
+                                      patch_size=get_param(arch_params, "patch_size", (16, 16)),
+                                      num_classes=num_classes or arch_params.num_classes,
+                                      hidden_dim=768, depth=12, heads=12, mlp_dim=3072,
+                                      in_channels=get_param(arch_params, 'in_channels', 3),
+                                      dropout_prob=get_param(arch_params, "dropout_prob", 0),
+                                      emb_dropout_prob=get_param(arch_params, "emb_dropout_prob", 0),
+                                      backbone_mode=backbone_mode)
+
+
+class ViTLarge(ViT):
+    def __init__(self, arch_params, num_classes=None, backbone_mode=None):
+        super(ViTLarge, self).__init__(image_size=get_param(arch_params, "image_size", (224, 224)),
+                                       patch_size=get_param(arch_params, "patch_size", (16, 16)),
+                                       num_classes=num_classes or arch_params.num_classes,
+                                       hidden_dim=1024, depth=24, heads=16, mlp_dim=4096,
+                                       in_channels=get_param(arch_params, 'in_channels', 3),
+                                       dropout_prob=get_param(arch_params, "dropout_prob", 0),
+                                       emb_dropout_prob=get_param(arch_params, "emb_dropout_prob", 0),
+                                       backbone_mode=backbone_mode)
+
+
+class ViTHuge(ViT):
+    def __init__(self, arch_params, num_classes=None, backbone_mode=None):
+        super(ViTHuge, self).__init__(image_size=get_param(arch_params, "image_size", (224, 224)),
+                                      patch_size=get_param(arch_params, "patch_size", (16, 16)),
+                                      num_classes=num_classes or arch_params.num_classes,
+                                      hidden_dim=1280, depth=32, heads=16, mlp_dim=5120,
+                                      in_channels=get_param(arch_params, 'in_channels', 3),
+                                      dropout_prob=get_param(arch_params, "dropout_prob", 0),
+                                      emb_dropout_prob=get_param(arch_params, "emb_dropout_prob", 0),
+                                      backbone_mode=backbone_mode)
@@ -14,7 +14,8 @@ class AllArchitecturesTest(unittest.TestCase):
                                             'threshold': 1,
                                             'sml_net': torch.nn.Identity(),
                                             'big_net': torch.nn.Identity(),
-                                            'dropout': 0})
+                                            'dropout': 0,
+                                            'build_residual_branches': True})
 
     def test_architecture_is_sg_module(self):
         """
@@ -22,7 +23,7 @@ class AllArchitecturesTest(unittest.TestCase):
         """
         """
         for arch_name in ARCHITECTURES:
         for arch_name in ARCHITECTURES:
             # skip custom constructors to keep all_arch_params as general as a possible
             # skip custom constructors to keep all_arch_params as general as a possible
-            if 'custom' in arch_name.lower() or 'nas' in arch_name.lower():
+            if 'custom' in arch_name.lower() or 'nas' in arch_name.lower() or 'kd' in arch_name.lower():
                 continue
             self.assertTrue(isinstance(ARCHITECTURES[arch_name](arch_params=self.all_arch_params), SgModule))
 