#638 Bug/sg 000 update kd train from config

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bug/SG-000_update_kd_train_from_config
1 changed file with 28 additions and 9 deletions
src/super_gradients/training/kd_trainer/kd_trainer.py (+28 −9)
@@ -1,15 +1,16 @@
 import hydra
 import torch.nn
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from torch.utils.data import DataLoader

+from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.common import MultiGPUMode
 from super_gradients.training.dataloaders import dataloaders
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.sg_trainer import Trainer
-from typing import Union
+from typing import Union, Dict
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
@@ -25,7 +26,6 @@ from super_gradients.training.exceptions.kd_trainer_exceptions import (
 )
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
 from super_gradients.training.utils.ema import KDModelEMA
-from super_gradients.training.utils.sg_trainer_utils import parse_args

 logger = get_logger(__name__)

@@ -47,11 +47,16 @@ class KDTrainer(Trainer):
         @return: output of kd_trainer.train(...) (i.e results tuple)
         """
         # INSTANTIATE ALL OBJECTS IN CFG
-        cfg = hydra.utils.instantiate(cfg)
+        setup_device(
+            device=core_utils.get_param(cfg, "device"),
+            multi_gpu=core_utils.get_param(cfg, "multi_gpu"),
+            num_gpus=core_utils.get_param(cfg, "num_gpus"),
+        )

-        kwargs = parse_args(cfg, cls.__init__)
+        # INSTANTIATE ALL OBJECTS IN CFG
+        cfg = hydra.utils.instantiate(cfg)

-        trainer = KDTrainer(**kwargs)
+        trainer = KDTrainer(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)

         # INSTANTIATE DATA LOADERS
         train_dataloader = dataloaders.get(
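Note on the hunk above: the new flow calls setup_device() before hydra.utils.instantiate(cfg), so device/DDP state is configured before any dataloaders or models are built from the config, and the trainer is now constructed explicitly from experiment_name and ckpt_root_dir rather than by parsing the whole config into __init__ kwargs via parse_args. A minimal sketch of how train_from_config would be invoked from a Hydra entry point; the recipe name and config path below are assumptions for illustration, not part of this PR:

```python
import hydra
from omegaconf import DictConfig

from super_gradients.training.kd_trainer.kd_trainer import KDTrainer


@hydra.main(config_path="recipes", config_name="my_kd_recipe")  # hypothetical recipe
def main(cfg: DictConfig) -> None:
    # Handles device setup, object instantiation, and training end to end.
    KDTrainer.train_from_config(cfg)


if __name__ == "__main__":
    main()
```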
@@ -80,6 +85,8 @@ class KDTrainer(Trainer):
             load_backbone=cfg.teacher_checkpoint_params.load_backbone,
         )

+        recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
+
         # TRAIN
         trainer.train(
             training_params=cfg.training_hyperparams,
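For the recipe_logged_cfg line above: OmegaConf.to_container(cfg, resolve=True) converts the DictConfig into a plain Python dict with all ${...} interpolations resolved, which makes the recipe safe to serialize into the sg_logger. A standalone illustration (the config values here are made up):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"lr": 0.1, "warmup_lr": "${lr}"})
plain = OmegaConf.to_container(cfg, resolve=True)
print(plain)  # {'lr': 0.1, 'warmup_lr': 0.1} -- a plain dict, interpolation resolved
```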
@@ -90,6 +97,7 @@ class KDTrainer(Trainer):
             run_teacher_on_eval=cfg.run_teacher_on_eval,
             train_loader=train_dataloader,
             valid_loader=val_dataloader,
+            additional_configs_to_log=recipe_logged_cfg,
         )

     def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
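The new additional_configs_to_log argument is threaded from train_from_config into trainer.train() here; per the docstring added further down, it maps a config title to a plain dict. A hypothetical value (titles and contents are made up for illustration):

```python
additional_configs_to_log = {
    "recipe_config": {"experiment_name": "kd_run", "max_epochs": 100},
    "environment": {"torch": "1.12.0"},
}
```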
@@ -275,20 +283,22 @@ class KDTrainer(Trainer):
     def train(
         self,
         model: KDModule = None,
-        training_params: dict = dict(),
+        training_params: Dict = None,
         student: SgModule = None,
         teacher: torch.nn.Module = None,
         kd_architecture: Union[KDModule.__class__, str] = "kd_module",
-        kd_arch_params: dict = dict(),
+        kd_arch_params: Dict = None,
         run_teacher_on_eval=False,
         train_loader: DataLoader = None,
         valid_loader: DataLoader = None,
+        additional_configs_to_log: Dict = None,
         *args,
         **kwargs,
     ):
         """
         """
         Trains the student network (wrapped in KDModule network).
         Trains the student network (wrapped in KDModule network).
 
 
+
         :param model: KDModule, network to train. When none is given will initialize KDModule according to kd_architecture,
             student and teacher (default=None)
         :param training_params: dict, Same as in Trainer.train()
@@ -299,12 +309,21 @@ class KDTrainer(Trainer):
         :param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
         :param train_loader: Dataloader for train set.
         :param valid_loader: Dataloader for validation.
+        :param additional_configs_to_log: Dict, dictionary containing configs that will be added to the training's
+                sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2":{..}}, (optional, default=None)
         """
         """
         kd_net = self.net or model
         kd_net = self.net or model
+        kd_arch_params = kd_arch_params or dict()
         if kd_net is None:
             if student is None or teacher is None:
                 raise ValueError("Must pass student and teacher models or net (KDModule).")
             kd_net = self._instantiate_kd_net(
                 arch_params=HpmStruct(**kd_arch_params), architecture=kd_architecture, run_teacher_on_eval=run_teacher_on_eval, student=student, teacher=teacher
             )
-        super(KDTrainer, self).train(model=kd_net, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
+        super(KDTrainer, self).train(
+            model=kd_net,
+            training_params=training_params,
+            train_loader=train_loader,
+            valid_loader=valid_loader,
+            additional_configs_to_log=additional_configs_to_log,
+        )
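A note on the signature changes in this last hunk: moving from training_params: dict = dict() to Dict = None (and likewise kd_arch_params) avoids Python's shared-mutable-default pitfall, since a default dict() is evaluated once at function definition and then reused across every call. A self-contained demonstration of the pitfall and of the None-plus-or idiom this PR uses (function names are hypothetical):

```python
def configure_bad(params: dict = dict()):
    # The same dict object is reused on every call made without an argument.
    params.setdefault("calls", 0)
    params["calls"] += 1
    return params


def configure_good(params: dict = None):
    params = params or dict()  # fresh dict per call, as in `kd_arch_params = kd_arch_params or dict()`
    params.setdefault("calls", 0)
    params["calls"] += 1
    return params


print(configure_bad())   # {'calls': 1}
print(configure_bad())   # {'calls': 2} -- state leaked between calls
print(configure_good())  # {'calls': 1}
print(configure_good())  # {'calls': 1}
```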