#364 build_model refs replaced

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000_remove_build_model
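In short, this PR drops the deprecated trainer.build_model(...) entry point: a model is now created explicitly with models.get(...) and handed to trainer.train(...) via the model argument. A minimal sketch of the new flow, pieced together from the hunks below; the architecture name, experiment name and the abbreviated hyper-parameters are illustrative placeholders, not part of this PR:

from super_gradients import Trainer
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy, Top5

trainer = Trainer("build_model_migration_example")               # experiment name is illustrative
model = models.get("resnet18", arch_params={"num_classes": 10})  # replaces trainer.build_model(...)

train_params = {"max_epochs": 2,                                 # hyper-parameters abbreviated, see docs for the full list
                "lr_mode": "step", "lr_updates": [1], "lr_decay_factor": 0.1, "initial_lr": 0.1,
                "loss": "cross_entropy", "optimizer": "SGD",
                "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                "metric_to_watch": "Accuracy", "greater_metric_to_watch_is_better": True}

trainer.train(model=model,
              training_params=train_params,
              train_loader=classification_test_dataloader(),
              valid_loader=classification_test_dataloader())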
@@ -4,6 +4,8 @@ Deci-lab model export example.
 The main purpose of this code is to demonstrate how to upload the model to the platform, optimize and download it
  after training is complete, using DeciPlatformCallback.
 """
+from super_gradients.training import models
+
 from super_gradients import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
@@ -30,7 +32,7 @@ def main(architecture_name: str):
         ckpt_root_dir=checkpoint_dir,
     )
 
-    trainer.build_model(architecture=architecture_name, arch_params={"use_aux_heads": True, "aux_head": True})
+    model = models.get(architecture=architecture_name, arch_params={"use_aux_heads": True, "aux_head": True})
 
     # CREATE META-DATA, AND OPTIMIZATION REQUEST FORM FOR DECI PLATFORM POST TRAINING CALLBACK
     model_name = f"{architecture_name}_for_deci_lab_export_example"
@@ -90,7 +92,7 @@ def main(architecture_name: str):
 
     # RUN TRAINING. ONCE ALL EPOCHS ARE DONE THE OPTIMIZED MODEL FILE WILL BE LOCATED IN THE EXPERIMENT'S
     # CHECKPOINT DIRECTORY
-    trainer.train(train_params, train_loader=classification_test_dataloader(),
+    trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(),
                   valid_loader=classification_test_dataloader())
 
 
@@ -24,7 +24,6 @@ trainer = Trainer("regseg48_transfer_learning_old_dice_diff_lrs_head_fixed_50_ep
 # LOADING THE PRETRAINED REGSET, IT WILL CALL IT'S REPLACE_HEAD METHOD AND CHANGE IT'S SEGMENTATION HEAD LAYER ACCORDING
 # TO OUR BINARY SEGMENTATION DATASET
 model = models.get("regseg48", pretrained_weights="cityscapes", num_classes=1)
-trainer.build_model("regseg48", arch_params={"pretrained_weights": "cityscapes"})
 
 # DEFINE TRAINING PARAMS. SEE DOCS FOR THE FULL LIST.
 train_params = {"max_epochs": 50,
@@ -52,4 +51,7 @@ train_params = {"max_epochs": 50,
                                                                             last_img_idx_in_batch=4)],
                 }
 
-trainer.train(train_params)
+trainer.train(model=model,
+              training_params=train_params,
+              train_loader=dl_train,
+              valid_loader=dl_val)
@@ -529,14 +529,14 @@ ALL_DATALOADERS = {"coco2017_train": coco2017_train,
                    }
 
 
-def get(name: str, dataset_params: Dict = None, dataloader_params: Dict = None):
+def get(name: str, dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
     """
     """
+    Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.
 
-
-    :param name:
-    :param dataset_params:
-    :param dataloader_params:
-    :return:
+    :param name: dataset name in ALL_DATALOADERS.
+    :param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.
+    :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__
+    :return: initialized DataLoader.
     """
     """
 
 
     if name not in ALL_DATALOADERS.keys():
     if name not in ALL_DATALOADERS.keys():
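For reference, a hedged usage sketch of the factory documented above; only the signature and the name lookup in ALL_DATALOADERS come from this hunk, the override values are illustrative:

from torch.utils.data import DataLoader

from super_gradients.training.dataloaders.dataloaders import get

# "coco2017_train" is one of the names registered in ALL_DATALOADERS (see the hunk header above).
# dataloader_params are forwarded to DataLoader.__init__ and override the recipe's yaml defaults.
train_loader = get("coco2017_train", dataloader_params={"batch_size": 16, "num_workers": 4})
assert isinstance(train_loader, DataLoader)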
@@ -85,73 +85,6 @@ class KDTrainer(Trainer):
                       run_teacher_on_eval=cfg.run_teacher_on_eval,
                       train_loader=train_dataloader, valid_loader=val_dataloader)
 
-    def build_model(self,
-                    # noqa: C901 - too complex
-                    architecture: Union[str, KDModule] = 'kd_module',
-                    arch_params={}, checkpoint_params={},
-                    *args, **kwargs):
-        """
-        :param architecture: (Union[str, KDModule]) Defines the network's architecture from models/KD_ARCHITECTURES
-         (default='kd_module')
-
-        :param arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc to be passed to kd
-            architecture class (discarded when architecture is KDModule instance)
-
-        :param checkpoint_params: (dict) A dictionary like object with the following keys/values:
-
-              student_pretrained_weights:   String describing the dataset of the pretrained weights (for example
-              "imagenent") for the student network.
-
-              teacher_pretrained_weights:   String describing the dataset of the pretrained weights (for example
-              "imagenent") for the teacher network.
-
-              teacher_checkpoint_path:    Local path to the teacher's checkpoint. Note that when passing pretrained_weights
-                                   through teacher_arch_params these weights will be overridden by the
-                                   pretrained checkpoint. (default=None)
-
-              load_kd_model_checkpoint:   Whether to load an entire KDModule checkpoint (used to continue KD training)
-               (default=False)
-
-              kd_model_source_ckpt_folder_name: Folder name to load an entire KDModule checkpoint from
-                (self.experiment_name if none is given) to resume KD training (default=None)
-
-              kd_model_external_checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative
-                                               (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
-                                               load the checkpoint even if the load_checkpoint flag is not provided.
-                                               (deafult=None)
-
-        :keyword student_architecture: (Union[str, SgModule]) Defines the student's architecture from
-            models/ALL_ARCHITECTURES (when str), or directly defined the student network (when SgModule).
-
-        :keyword teacher_architecture: (Union[str, SgModule]) Defines the teacher's architecture from
-            models/ALL_ARCHITECTURES (when str), or directly defined the teacher network (when SgModule).
-
-        :keyword student_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for student
-            net. (deafult={})
-
-        :keyword teacher_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for teacher
-            net. (deafult={})
-
-        :keyword run_teacher_on_eval: (bool)- whether to run self.teacher at eval mode regardless of self.train(mode)
-
-
-        """
-        kwargs.setdefault("student_architecture", None)
-        kwargs.setdefault("teacher_architecture", None)
-        kwargs.setdefault("student_arch_params", {})
-        kwargs.setdefault("teacher_arch_params", {})
-        kwargs.setdefault("run_teacher_on_eval", False)
-
-        self._validate_args(arch_params, architecture, checkpoint_params, **kwargs)
-
-        self.student_architecture = kwargs.get("student_architecture")
-        self.teacher_architecture = kwargs.get("teacher_architecture")
-        self.student_arch_params = kwargs.get("student_arch_params")
-        self.teacher_arch_params = kwargs.get("teacher_arch_params")
-
-        super(KDTrainer, self).build_model(architecture=architecture, arch_params=arch_params,
-                                           checkpoint_params=checkpoint_params, **kwargs)
-
     def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
         student_architecture = get_param(kwargs, "student_architecture")
         teacher_architecture = get_param(kwargs, "teacher_architecture")
@@ -7,7 +7,6 @@ from typing import Union, Tuple, Mapping, List, Any
 import hydra
 import numpy as np
 import torch
-from deprecate import deprecated
 from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import DataLoader, DistributedSampler
@@ -249,48 +248,6 @@ class Trainer:
         self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
             HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
 
-    # FIXME - we need to resolve flake8's 'function is too complex' for this function
-    @deprecated(target=None, deprecated_in='2.3.0', remove_in='3.0.0')
-    def build_model(self,  # noqa: C901 - too complex
-                    architecture: Union[str, nn.Module],
-                    arch_params={}, checkpoint_params={}, *args, **kwargs):
-        """
-        :param architecture:               Defines the network's architecture from models/ALL_ARCHITECTURES
-        :param arch_params:                Architecture H.P. e.g.: block, num_blocks, num_classes, etc.
-        :param checkpoint_params:          Dictionary like object with the following key:values:
-
-            load_checkpoint:            Load a pre-trained checkpoint
-            strict_load:                See StrictLoad class documentation for details.
-            source_ckpt_folder_name:    folder name to load the checkpoint from (self.experiment_name if none is given)
-            load_weights_only:          loads only the weight from the checkpoint and zeroize the training params
-            load_backbone:              loads the provided checkpoint to self.net.backbone instead of self.net
-            external_checkpoint_path:   The path to the external checkpoint to be loaded. Can be absolute or relative
-                                               (ie: path/to/checkpoint.pth). If provided, will automatically attempt to
-                                               load the checkpoint even if the load_checkpoint flag is not provided.
-
-        """
-        if 'num_classes' not in arch_params.keys():
-            if self.classes is None and self.dataset_interface is None:
-                raise Exception('Error', 'Number of classes not defined in arch params and dataset is not defined')
-            else:
-                arch_params['num_classes'] = len(self.classes)
-
-        self.arch_params = core_utils.HpmStruct(**arch_params)
-        self.checkpoint_params = core_utils.HpmStruct(**checkpoint_params)
-
-        self.net = self._instantiate_net(architecture, self.arch_params, checkpoint_params, *args, **kwargs)
-
-        # SAVE THE ARCHITECTURE FOR NEURAL ARCHITECTURE SEARCH
-
-        self.architecture = architecture
-
-        self._net_to_device()
-
-        # SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE
-        self.update_param_groups = hasattr(self.net.module, 'update_param_groups')
-
-        self._load_checkpoint_to_model()
-
     def _set_ckpt_loading_attributes(self):
         """
         Sets checkpoint loading related attributes according to self.checkpoint_params
@@ -1783,7 +1740,6 @@ class Trainer:
                                                set_net=self.set_net,
                                                set_ckpt_best_name=self.set_ckpt_best_name,
                                                reset_best_metric=self._reset_best_metric,
-                                               build_model=self.build_model,
                                                validate_epoch=self._validate_epoch,
                                                set_ema=self.set_ema)
         else:
@@ -10,10 +10,11 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 import torch
 from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.training import models
 from super_gradients.training.utils.callbacks import Phase, PhaseCallback, PhaseContext
 import os
 from enum import Enum
-from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model, read_ckpt_state_dict
+from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model
 from super_gradients.training.utils import get_param
 from super_gradients.training.utils.distributed_training_utils import get_local_rank, \
     get_world_size
@@ -261,21 +262,6 @@ class QATCallback(PhaseCallback):
 
     def __call__(self, context: PhaseContext):
         if context.epoch == self.start_epoch:
-            # SET CHECKPOINT PARAMS SO WE LOAD THE BEST CHECKPOINT SO FAR
-            checkpoint_params_qat = context.checkpoint_params.to_dict()
-            checkpoint_params_qat['ckpt_name'] = 'ckpt_best.pth'
-
-            if self.calibrated_model_path is not None:
-                checkpoint_params_qat['external_checkpoint_path'] = self.calibrated_model_path
-                checkpoint_params_qat['load_ema_as_net'] = 'ema_net' in read_ckpt_state_dict(self.calibrated_model_path).keys()
-                checkpoint_params_qat['load_checkpoint'] = True
-            elif self.start_epoch > 0:
-                checkpoint_params_qat['load_ema_as_net'] = context.training_params.ema
-                checkpoint_params_qat['load_checkpoint'] = True
-                if checkpoint_params_qat['load_ema_as_net']:
-                    logger.warning("EMA net loaded from best checkpoint, continuing QAT without EMA.")
-                    context.context_methods.set_ema(False)
-
             # REMOVE REFERENCES TO NETWORK AND CLEAN GPU MEMORY BEFORE BUILDING THE NEW NET
             context.context_methods.set_net(None)
             context.net = None
@@ -283,9 +269,14 @@
 
             # BUILD THE SAME MODEL BUT WITH FAKE QUANTIZED MODULES, AND LOAD BEST CHECKPOINT TO IT
             self._initialize_quant_modules()
-            context.context_methods.build_model(architecture=context.architecture,
-                                                arch_params=context.arch_params.to_dict(),
-                                                checkpoint_params=checkpoint_params_qat)
+
+            if self.calibrated_model_path is not None:
+                checkpoint_path = self.calibrated_model_path
+            elif self.start_epoch > 0:
+                checkpoint_path = os.path.join(context.ckpt_dir, 'ckpt_best.pth')
+
+            qat_net = models.get(context.architecture, arch_params=context.arch_params.to_dict(), checkpoint_path=checkpoint_path)
+
             _deactivate_quant_modules_wrapping()
 
             # UPDATE CONTEXT'S NET REFERENCE
@@ -300,6 +291,9 @@ class QATCallback(PhaseCallback):
             # SET NEW FILENAME FOR THE BEST CHECKPOINT SO WE DON'T OVERRIDE THE PREVIOUS ONES
             context.context_methods.set_ckpt_best_name('qat_ckpt_best.pth')
 
+            # FINALLY, SET THE QAT NET TO CONTINUE TRAINING
+            context.context_methods.set_net(qat_net)
+
     def _calibrate_model(self, context: PhaseContext):
         """
         Performs model calibration (collecting stats and setting amax for the fake quantized moduls)
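The QAT callback above now rebuilds the fake-quantized network directly through models.get with a checkpoint_path, instead of going through context_methods.build_model. A minimal sketch of that call pattern; the architecture, arch_params and path are placeholders, only the keyword names come from the hunk:

from super_gradients.training import models

# Rebuild the architecture and load its weights from a specific checkpoint file,
# mirroring how QATCallback restores the best (or externally calibrated) checkpoint.
qat_net = models.get("resnet18",
                     arch_params={"num_classes": 10},
                     checkpoint_path="/path/to/ckpt_best.pth")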