Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#544 Feature/sg 456 centralize ddp setup

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-456-centralise_ddp_setup
@@ -3,6 +3,7 @@ from super_gradients.training import ARCHITECTURES, losses, utils, datasets_util
 from super_gradients.examples.train_from_recipe_example import train_from_recipe
 from super_gradients.examples.train_from_recipe_example import train_from_recipe
 from super_gradients.examples.train_from_kd_recipe_example import train_from_kd_recipe
 from super_gradients.examples.train_from_kd_recipe_example import train_from_kd_recipe
 from super_gradients.sanity_check import env_sanity_check
 from super_gradients.sanity_check import env_sanity_check
+from super_gradients.training.utils.distributed_training_utils import setup_device
 
 
 __all__ = [
 __all__ = [
     "ARCHITECTURES",
     "ARCHITECTURES",
@@ -18,6 +19,7 @@ __all__ = [
     "train_from_recipe",
     "train_from_recipe",
     "train_from_kd_recipe",
     "train_from_kd_recipe",
     "env_sanity_check",
     "env_sanity_check",
+    "setup_device",
 ]
 ]
 
 
 __version__ = "3.0.5"
 __version__ = "3.0.5"
Discard
@@ -1,8 +1,11 @@
 import argparse
 import argparse
 import sys
 import sys
 from typing import Any
 from typing import Any
+from super_gradients.common.abstractions.abstract_logger import get_logger
 
 
 
 
+logger = get_logger(__name__)
+
 EXTRA_ARGS = []
 EXTRA_ARGS = []
 
 
 
 
@@ -18,3 +21,11 @@ def pop_arg(arg_name: str, default_value: Any = None) -> Any:
         EXTRA_ARGS.append(val)
         EXTRA_ARGS.append(val)
         sys.argv.remove(val)
         sys.argv.remove(val)
     return vars(args)[arg_name]
     return vars(args)[arg_name]
+
+
+def pop_local_rank() -> int:
+    """Pop the python arg "local-rank". If exists inform the user with a log, otherwise return -1."""
+    local_rank = pop_arg("local_rank", default_value=-1)
+    if local_rank != -1:
+        logger.info("local_rank was automatically parsed from your config.")
+    return local_rank
Discard
@@ -2,12 +2,9 @@ import os
 import socket
 import socket
 from functools import wraps
 from functools import wraps
 
 
-from super_gradients.common.environment.argparse_utils import pop_arg
+from super_gradients.common.environment.device_utils import device_config
 from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
 from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
-
-
-DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=-1))
-INIT_TRAINER = False
+from super_gradients.common.environment.argparse_utils import pop_local_rank
 
 
 
 
 def init_trainer():
 def init_trainer():
@@ -15,28 +12,14 @@ def init_trainer():
     Initialize the super_gradients environment.
     Initialize the super_gradients environment.
 
 
     This function should be the first thing to be called by any code running super_gradients.
     This function should be the first thing to be called by any code running super_gradients.
-    It resolves conflicts between the different tools, packages and environments used and prepares the super_gradients environment.
     """
     """
-    global INIT_TRAINER, DDP_LOCAL_RANK
-
-    if not INIT_TRAINER:
-        register_hydra_resolvers()
-
-        # We pop local_rank if it was specified in the args, because it would break
-        args_local_rank = pop_arg("local_rank", default_value=-1)
-
-        # Set local_rank with priority order (env variable > args.local_rank > args.default_value)
-        DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=args_local_rank))
-        INIT_TRAINER = True
+    register_hydra_resolvers()
+    pop_local_rank()
 
 
 
 
 def is_distributed() -> bool:
 def is_distributed() -> bool:
-    return DDP_LOCAL_RANK >= 0
-
-
-def is_rank_0() -> bool:
-    """Check if the node was launched with torch.distributed.launch and if the node is of rank 0"""
-    return os.getenv("LOCAL_RANK") == "0"
+    """Check if current process is a DDP subprocess."""
+    return device_config.assigned_rank >= 0
 
 
 
 
 def is_launched_using_sg():
 def is_launched_using_sg():
@@ -55,7 +38,9 @@ def is_main_process():
     """
     """
     if not is_distributed():  # If no DDP, or DDP launching process
     if not is_distributed():  # If no DDP, or DDP launching process
         return True
         return True
-    elif is_rank_0() and not is_launched_using_sg():  # If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0
+    elif (
+        device_config.assigned_rank == 0 and not is_launched_using_sg()
+    ):  # If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0
         return True
         return True
     else:
     else:
         return False
         return False
@@ -74,7 +59,7 @@ def multi_process_safe(func):
 
 
     @wraps(func)
     @wraps(func)
     def wrapper(*args, **kwargs):
     def wrapper(*args, **kwargs):
-        if DDP_LOCAL_RANK <= 0:
+        if device_config.assigned_rank <= 0:
             return func(*args, **kwargs)
             return func(*args, **kwargs)
         else:
         else:
             return do_nothing(*args, **kwargs)
             return do_nothing(*args, **kwargs)
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
  1. import os
  2. import dataclasses
  3. import torch
  4. from super_gradients.common.environment.argparse_utils import pop_local_rank
  5. __all__ = ["device_config"]
  6. def _get_assigned_rank() -> int:
  7. """Get the rank assigned by DDP launcher. If not DDP subprocess, return -1."""
  8. if os.getenv("LOCAL_RANK") is not None:
  9. return int(os.getenv("LOCAL_RANK"))
  10. else:
  11. return pop_local_rank()
  12. @dataclasses.dataclass
  13. class DeviceConfig:
  14. device: str = "cuda" if torch.cuda.is_available() else "cpu"
  15. multi_gpu: str = None
  16. assigned_rank: str = dataclasses.field(default=_get_assigned_rank(), init=False)
  17. # Singleton holding the device information
  18. device_config = DeviceConfig()
Discard
@@ -12,7 +12,6 @@ Main purpose is to demonstrate training in SG with minimal abstraction and maxim
 
 
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
 from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
-from super_gradients.training import MultiGPUMode
 from torch.optim import ASGD
 from torch.optim import ASGD
 from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau
 from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau
 from torch.nn import CrossEntropyLoss
 from torch.nn import CrossEntropyLoss
@@ -49,7 +48,7 @@ phase_callbacks = [
 ]
 ]
 
 
 # Bring everything together with Trainer and start training
 # Bring everything together with Trainer and start training
-trainer = Trainer("Cifar10_external_objects_example", multi_gpu=MultiGPUMode.OFF)
+trainer = Trainer("Cifar10_external_objects_example")
 
 
 train_params = {
 train_params = {
     "max_epochs": 300,
     "max_epochs": 300,
Discard
@@ -14,11 +14,15 @@ from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.utils import get_param, HpmStruct
 from super_gradients.training.utils import get_param, HpmStruct
-from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, \
-    load_checkpoint_to_model
-from super_gradients.training.exceptions.kd_trainer_exceptions import ArchitectureKwargsException, \
-    UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
-    TeacherKnowledgeException, UndefinedNumClassesException
+from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
+from super_gradients.training.exceptions.kd_trainer_exceptions import (
+    ArchitectureKwargsException,
+    UnsupportedKDArchitectureException,
+    InconsistentParamsException,
+    UnsupportedKDModelArgException,
+    TeacherKnowledgeException,
+    UndefinedNumClassesException,
+)
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
 from super_gradients.training.utils.ema import KDModelEMA
 from super_gradients.training.utils.ema import KDModelEMA
 from super_gradients.training.utils.sg_trainer_utils import parse_args
 from super_gradients.training.utils.sg_trainer_utils import parse_args
@@ -27,8 +31,7 @@ logger = get_logger(__name__)
 
 
 
 
 class KDTrainer(Trainer):
 class KDTrainer(Trainer):
-    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 ckpt_root_dir: str = None):
+    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = None, ckpt_root_dir: str = None):
         super().__init__(experiment_name=experiment_name, device=device, multi_gpu=multi_gpu, ckpt_root_dir=ckpt_root_dir)
         super().__init__(experiment_name=experiment_name, device=device, multi_gpu=multi_gpu, ckpt_root_dir=ckpt_root_dir)
         self.student_architecture = None
         self.student_architecture = None
         self.teacher_architecture = None
         self.teacher_architecture = None
@@ -51,31 +54,43 @@ class KDTrainer(Trainer):
         trainer = KDTrainer(**kwargs)
         trainer = KDTrainer(**kwargs)
 
 
         # INSTANTIATE DATA LOADERS
         # INSTANTIATE DATA LOADERS
-        train_dataloader = dataloaders.get(name=cfg.train_dataloader,
-                                           dataset_params=cfg.dataset_params.train_dataset_params,
-                                           dataloader_params=cfg.dataset_params.train_dataloader_params)
-
-        val_dataloader = dataloaders.get(name=cfg.val_dataloader,
-                                         dataset_params=cfg.dataset_params.val_dataset_params,
-                                         dataloader_params=cfg.dataset_params.val_dataloader_params)
-
-        student = models.get(cfg.student_architecture, arch_params=cfg.student_arch_params,
-                             strict_load=cfg.student_checkpoint_params.strict_load,
-                             pretrained_weights=cfg.student_checkpoint_params.pretrained_weights,
-                             checkpoint_path=cfg.student_checkpoint_params.checkpoint_path,
-                             load_backbone=cfg.student_checkpoint_params.load_backbone)
-
-        teacher = models.get(cfg.teacher_architecture, arch_params=cfg.teacher_arch_params,
-                             strict_load=cfg.teacher_checkpoint_params.strict_load,
-                             pretrained_weights=cfg.teacher_checkpoint_params.pretrained_weights,
-                             checkpoint_path=cfg.teacher_checkpoint_params.checkpoint_path,
-                             load_backbone=cfg.teacher_checkpoint_params.load_backbone)
+        train_dataloader = dataloaders.get(
+            name=cfg.train_dataloader, dataset_params=cfg.dataset_params.train_dataset_params, dataloader_params=cfg.dataset_params.train_dataloader_params
+        )
+
+        val_dataloader = dataloaders.get(
+            name=cfg.val_dataloader, dataset_params=cfg.dataset_params.val_dataset_params, dataloader_params=cfg.dataset_params.val_dataloader_params
+        )
+
+        student = models.get(
+            cfg.student_architecture,
+            arch_params=cfg.student_arch_params,
+            strict_load=cfg.student_checkpoint_params.strict_load,
+            pretrained_weights=cfg.student_checkpoint_params.pretrained_weights,
+            checkpoint_path=cfg.student_checkpoint_params.checkpoint_path,
+            load_backbone=cfg.student_checkpoint_params.load_backbone,
+        )
+
+        teacher = models.get(
+            cfg.teacher_architecture,
+            arch_params=cfg.teacher_arch_params,
+            strict_load=cfg.teacher_checkpoint_params.strict_load,
+            pretrained_weights=cfg.teacher_checkpoint_params.pretrained_weights,
+            checkpoint_path=cfg.teacher_checkpoint_params.checkpoint_path,
+            load_backbone=cfg.teacher_checkpoint_params.load_backbone,
+        )
 
 
         # TRAIN
         # TRAIN
-        trainer.train(training_params=cfg.training_hyperparams, student=student, teacher=teacher,
-                      kd_architecture=cfg.architecture, kd_arch_params=cfg.arch_params,
-                      run_teacher_on_eval=cfg.run_teacher_on_eval,
-                      train_loader=train_dataloader, valid_loader=val_dataloader)
+        trainer.train(
+            training_params=cfg.training_hyperparams,
+            student=student,
+            teacher=teacher,
+            kd_architecture=cfg.architecture,
+            kd_arch_params=cfg.arch_params,
+            run_teacher_on_eval=cfg.run_teacher_on_eval,
+            train_loader=train_dataloader,
+            valid_loader=val_dataloader,
+        )
 
 
     def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
     def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
         student_architecture = get_param(kwargs, "student_architecture")
         student_architecture = get_param(kwargs, "student_architecture")
@@ -83,7 +98,7 @@ class KDTrainer(Trainer):
         student_arch_params = get_param(kwargs, "student_arch_params")
         student_arch_params = get_param(kwargs, "student_arch_params")
         teacher_arch_params = get_param(kwargs, "teacher_arch_params")
         teacher_arch_params = get_param(kwargs, "teacher_arch_params")
 
 
-        if get_param(checkpoint_params, 'pretrained_weights') is not None:
+        if get_param(checkpoint_params, "pretrained_weights") is not None:
             raise UnsupportedKDModelArgException("pretrained_weights", "checkpoint_params")
             raise UnsupportedKDModelArgException("pretrained_weights", "checkpoint_params")
 
 
         if not isinstance(architecture, KDModule):
         if not isinstance(architecture, KDModule):
@@ -95,24 +110,23 @@ class KDTrainer(Trainer):
         # DERIVE NUMBER OF CLASSES FROM DATASET INTERFACE IF NOT SPECIFIED OR ARCH PARAMS FOR TEACHER AND STUDENT
         # DERIVE NUMBER OF CLASSES FROM DATASET INTERFACE IF NOT SPECIFIED OR ARCH PARAMS FOR TEACHER AND STUDENT
         self._validate_num_classes(student_arch_params, teacher_arch_params)
         self._validate_num_classes(student_arch_params, teacher_arch_params)
 
 
-        arch_params['num_classes'] = student_arch_params['num_classes']
+        arch_params["num_classes"] = student_arch_params["num_classes"]
 
 
         # MAKE SURE TEACHER'S PRETRAINED NUM CLASSES EQUALS TO THE ONES BELONGING TO STUDENT AS WE CAN'T REPLACE
         # MAKE SURE TEACHER'S PRETRAINED NUM CLASSES EQUALS TO THE ONES BELONGING TO STUDENT AS WE CAN'T REPLACE
         # THE TEACHER'S HEAD
         # THE TEACHER'S HEAD
-        teacher_pretrained_weights = core_utils.get_param(checkpoint_params, 'teacher_pretrained_weights',
-                                                          default_val=None)
+        teacher_pretrained_weights = core_utils.get_param(checkpoint_params, "teacher_pretrained_weights", default_val=None)
         if teacher_pretrained_weights is not None:
         if teacher_pretrained_weights is not None:
             teacher_pretrained_num_classes = PRETRAINED_NUM_CLASSES[teacher_pretrained_weights]
             teacher_pretrained_num_classes = PRETRAINED_NUM_CLASSES[teacher_pretrained_weights]
-            if teacher_pretrained_num_classes != teacher_arch_params['num_classes']:
-                raise InconsistentParamsException("Pretrained dataset number of classes", "teacher's arch params",
-                                                  "number of classes", "student's number of classes")
+            if teacher_pretrained_num_classes != teacher_arch_params["num_classes"]:
+                raise InconsistentParamsException(
+                    "Pretrained dataset number of classes", "teacher's arch params", "number of classes", "student's number of classes"
+                )
 
 
         teacher_checkpoint_path = get_param(checkpoint_params, "teacher_checkpoint_path")
         teacher_checkpoint_path = get_param(checkpoint_params, "teacher_checkpoint_path")
         load_kd_model_checkpoint = get_param(checkpoint_params, "load_checkpoint")
         load_kd_model_checkpoint = get_param(checkpoint_params, "load_checkpoint")
 
 
         # CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD
         # CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD
-        if not (teacher_pretrained_weights or teacher_checkpoint_path or load_kd_model_checkpoint or isinstance(
-                teacher_architecture, torch.nn.Module)):
+        if not (teacher_pretrained_weights or teacher_checkpoint_path or load_kd_model_checkpoint or isinstance(teacher_architecture, torch.nn.Module)):
             raise TeacherKnowledgeException()
             raise TeacherKnowledgeException()
 
 
     def _validate_num_classes(self, student_arch_params, teacher_arch_params):
     def _validate_num_classes(self, student_arch_params, teacher_arch_params):
@@ -125,9 +139,8 @@ class KDTrainer(Trainer):
         """
         """
         self._validate_subnet_num_classes(student_arch_params)
         self._validate_subnet_num_classes(student_arch_params)
         self._validate_subnet_num_classes(teacher_arch_params)
         self._validate_subnet_num_classes(teacher_arch_params)
-        if teacher_arch_params['num_classes'] != student_arch_params['num_classes']:
-            raise InconsistentParamsException("num_classes", "student_arch_params", "num_classes",
-                                              "teacher_arch_params")
+        if teacher_arch_params["num_classes"] != student_arch_params["num_classes"]:
+            raise InconsistentParamsException("num_classes", "student_arch_params", "num_classes", "teacher_arch_params")
 
 
     def _validate_subnet_num_classes(self, subnet_arch_params):
     def _validate_subnet_num_classes(self, subnet_arch_params):
         """
         """
@@ -138,14 +151,13 @@ class KDTrainer(Trainer):
 
 
         """
         """
 
 
-        if 'num_classes' not in subnet_arch_params.keys():
+        if "num_classes" not in subnet_arch_params.keys():
             if self.dataset_interface is None:
             if self.dataset_interface is None:
                 raise UndefinedNumClassesException()
                 raise UndefinedNumClassesException()
             else:
             else:
-                subnet_arch_params['num_classes'] = len(self.classes)
+                subnet_arch_params["num_classes"] = len(self.classes)
 
 
-    def _instantiate_net(self, architecture: Union[KDModule, KDModule.__class__, str], arch_params: dict,
-                         checkpoint_params: dict, *args, **kwargs) -> tuple:
+    def _instantiate_net(self, architecture: Union[KDModule, KDModule.__class__, str], arch_params: dict, checkpoint_params: dict, *args, **kwargs) -> tuple:
         """
         """
         Instantiates kd_module according to architecture and arch_params, handles pretrained weights for the student
         Instantiates kd_module according to architecture and arch_params, handles pretrained weights for the student
          and teacher networks, and the required module manipulation (i.e head replacement) for the teacher network.
          and teacher networks, and the required module manipulation (i.e head replacement) for the teacher network.
@@ -164,13 +176,11 @@ class KDTrainer(Trainer):
         teacher_arch_params = get_param(kwargs, "teacher_arch_params")
         teacher_arch_params = get_param(kwargs, "teacher_arch_params")
         student_arch_params = core_utils.HpmStruct(**student_arch_params)
         student_arch_params = core_utils.HpmStruct(**student_arch_params)
         teacher_arch_params = core_utils.HpmStruct(**teacher_arch_params)
         teacher_arch_params = core_utils.HpmStruct(**teacher_arch_params)
-        student_pretrained_weights = get_param(checkpoint_params, 'student_pretrained_weights')
-        teacher_pretrained_weights = get_param(checkpoint_params, 'teacher_pretrained_weights')
+        student_pretrained_weights = get_param(checkpoint_params, "student_pretrained_weights")
+        teacher_pretrained_weights = get_param(checkpoint_params, "teacher_pretrained_weights")
 
 
-        student = super()._instantiate_net(student_architecture, student_arch_params,
-                                           {"pretrained_weights": student_pretrained_weights})
-        teacher = super()._instantiate_net(teacher_architecture, teacher_arch_params,
-                                           {"pretrained_weights": teacher_pretrained_weights})
+        student = super()._instantiate_net(student_architecture, student_arch_params, {"pretrained_weights": student_pretrained_weights})
+        teacher = super()._instantiate_net(teacher_architecture, teacher_arch_params, {"pretrained_weights": teacher_pretrained_weights})
 
 
         run_teacher_on_eval = get_param(kwargs, "run_teacher_on_eval", default_val=False)
         run_teacher_on_eval = get_param(kwargs, "run_teacher_on_eval", default_val=False)
 
 
@@ -179,11 +189,9 @@ class KDTrainer(Trainer):
     def _instantiate_kd_net(self, arch_params, architecture, run_teacher_on_eval, student, teacher):
     def _instantiate_kd_net(self, arch_params, architecture, run_teacher_on_eval, student, teacher):
         if isinstance(architecture, str):
         if isinstance(architecture, str):
             architecture_cls = KD_ARCHITECTURES[architecture]
             architecture_cls = KD_ARCHITECTURES[architecture]
-            net = architecture_cls(arch_params=arch_params, student=student, teacher=teacher,
-                                   run_teacher_on_eval=run_teacher_on_eval)
+            net = architecture_cls(arch_params=arch_params, student=student, teacher=teacher, run_teacher_on_eval=run_teacher_on_eval)
         elif isinstance(architecture, KDModule.__class__):
         elif isinstance(architecture, KDModule.__class__):
-            net = architecture(arch_params=arch_params, student=student, teacher=teacher,
-                               run_teacher_on_eval=run_teacher_on_eval)
+            net = architecture(arch_params=arch_params, student=student, teacher=teacher, run_teacher_on_eval=run_teacher_on_eval)
         else:
         else:
             net = architecture
             net = architecture
         return net
         return net
@@ -201,18 +209,18 @@ class KDTrainer(Trainer):
             #  WARN THAT TEACHER_CKPT WILL OVERRIDE TEACHER'S PRETRAINED WEIGHTS
             #  WARN THAT TEACHER_CKPT WILL OVERRIDE TEACHER'S PRETRAINED WEIGHTS
             teacher_pretrained_weights = get_param(self.checkpoint_params, "teacher_pretrained_weights")
             teacher_pretrained_weights = get_param(self.checkpoint_params, "teacher_pretrained_weights")
             if teacher_pretrained_weights:
             if teacher_pretrained_weights:
-                logger.warning(
-                    teacher_checkpoint_path + " checkpoint is "
-                                              "overriding " + teacher_pretrained_weights + " for teacher model")
+                logger.warning(teacher_checkpoint_path + " checkpoint is " "overriding " + teacher_pretrained_weights + " for teacher model")
 
 
             # ALWAYS LOAD ITS EMA IF IT EXISTS
             # ALWAYS LOAD ITS EMA IF IT EXISTS
-            load_teachers_ema = 'ema_net' in read_ckpt_state_dict(teacher_checkpoint_path).keys()
-            load_checkpoint_to_model(ckpt_local_path=teacher_checkpoint_path,
-                                     load_backbone=False,
-                                     net=teacher_net,
-                                     strict='no_key_matching',
-                                     load_weights_only=True,
-                                     load_ema_as_net=load_teachers_ema)
+            load_teachers_ema = "ema_net" in read_ckpt_state_dict(teacher_checkpoint_path).keys()
+            load_checkpoint_to_model(
+                ckpt_local_path=teacher_checkpoint_path,
+                load_backbone=False,
+                net=teacher_net,
+                strict="no_key_matching",
+                load_weights_only=True,
+                load_ema_as_net=load_teachers_ema,
+            )
 
 
         super(KDTrainer, self)._load_checkpoint_to_model()
         super(KDTrainer, self)._load_checkpoint_to_model()
 
 
@@ -229,15 +237,17 @@ class KDTrainer(Trainer):
         Creates a training hyper param config for logging with additional KD related hyper params.
         Creates a training hyper param config for logging with additional KD related hyper params.
         """
         """
         hyper_param_config = super()._get_hyper_param_config()
         hyper_param_config = super()._get_hyper_param_config()
-        hyper_param_config.update({"student_architecture": self.student_architecture,
-                                   "teacher_architecture": self.teacher_architecture,
-                                   "student_arch_params": self.student_arch_params,
-                                   "teacher_arch_params": self.teacher_arch_params
-                                   })
+        hyper_param_config.update(
+            {
+                "student_architecture": self.student_architecture,
+                "teacher_architecture": self.teacher_architecture,
+                "student_arch_params": self.student_arch_params,
+                "teacher_arch_params": self.teacher_arch_params,
+            }
+        )
         return hyper_param_config
         return hyper_param_config
 
 
-    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15,
-                               exp_activation: bool = True) -> KDModelEMA:
+    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
         """Instantiate KD ema model for KDModule.
         """Instantiate KD ema model for KDModule.
 
 
         If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
         If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
@@ -262,10 +272,20 @@ class KDTrainer(Trainer):
         state["net"] = best_net.state_dict()
         state["net"] = best_net.state_dict()
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
 
 
-    def train(self, model: KDModule = None, training_params: dict = dict(), student: SgModule = None,
-              teacher: torch.nn.Module = None, kd_architecture: Union[KDModule.__class__, str] = 'kd_module',
-              kd_arch_params: dict = dict(), run_teacher_on_eval=False, train_loader: DataLoader = None,
-              valid_loader: DataLoader = None, *args, **kwargs):
+    def train(
+        self,
+        model: KDModule = None,
+        training_params: dict = dict(),
+        student: SgModule = None,
+        teacher: torch.nn.Module = None,
+        kd_architecture: Union[KDModule.__class__, str] = "kd_module",
+        kd_arch_params: dict = dict(),
+        run_teacher_on_eval=False,
+        train_loader: DataLoader = None,
+        valid_loader: DataLoader = None,
+        *args,
+        **kwargs,
+    ):
         """
         """
         Trains the student network (wrapped in KDModule network).
         Trains the student network (wrapped in KDModule network).
 
 
@@ -284,10 +304,7 @@ class KDTrainer(Trainer):
         if kd_net is None:
         if kd_net is None:
             if student is None or teacher is None:
             if student is None or teacher is None:
                 raise ValueError("Must pass student and teacher models or net (KDModule).")
                 raise ValueError("Must pass student and teacher models or net (KDModule).")
-            kd_net = self._instantiate_kd_net(arch_params=HpmStruct(**kd_arch_params),
-                                              architecture=kd_architecture,
-                                              run_teacher_on_eval=run_teacher_on_eval,
-                                              student=student,
-                                              teacher=teacher)
-        super(KDTrainer, self).train(model=kd_net, training_params=training_params,
-                                     train_loader=train_loader, valid_loader=valid_loader)
+            kd_net = self._instantiate_kd_net(
+                arch_params=HpmStruct(**kd_arch_params), architecture=kd_architecture, run_teacher_on_eval=run_teacher_on_eval, student=student, teacher=teacher
+            )
+        super(KDTrainer, self).train(model=kd_net, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
Discard
@@ -17,15 +17,13 @@ from piptools.scripts.sync import _get_installed_distributions
 
 
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.distributed import DistributedSampler
 
 
-from super_gradients.common.factories.type_factory import TypeFactory
 from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
 from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
 
 
 from super_gradients.common.factories.callbacks_factory import CallbacksFactory
 from super_gradients.common.factories.callbacks_factory import CallbacksFactory
 from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad, EvaluationType
 from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad, EvaluationType
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.environment import ddp_utils
-from super_gradients.common.abstractions.abstract_logger import get_logger, mute_current_process
+from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.factories.list_factory import ListFactory
 from super_gradients.common.factories.list_factory import ListFactory
 from super_gradients.common.factories.losses_factory import LossesFactory
 from super_gradients.common.factories.losses_factory import LossesFactory
 from super_gradients.common.factories.metrics_factory import MetricsFactory
 from super_gradients.common.factories.metrics_factory import MetricsFactory
@@ -36,8 +34,8 @@ from super_gradients.training import utils as core_utils, models, dataloaders
 from super_gradients.training.models import SgModule
 from super_gradients.training.models import SgModule
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.utils import sg_trainer_utils, get_param
 from super_gradients.training.utils import sg_trainer_utils, get_param
-from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args, log_main_training_params
-from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, GPUModeNotSetupError
+from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
+from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat
 from super_gradients.training.metrics.metric_utils import (
 from super_gradients.training.metrics.metric_utils import (
     get_metrics_titles,
     get_metrics_titles,
     get_metrics_results_tuple,
     get_metrics_results_tuple,
@@ -51,11 +49,14 @@ from super_gradients.training.utils.distributed_training_utils import (
     reduce_results_tuple_for_ddp,
     reduce_results_tuple_for_ddp,
     compute_precise_bn_stats,
     compute_precise_bn_stats,
     setup_device,
     setup_device,
-    require_gpu_setup,
     get_gpu_mem_utilization,
     get_gpu_mem_utilization,
     get_world_size,
     get_world_size,
     get_local_rank,
     get_local_rank,
+    require_ddp_setup,
+    get_device_ids,
+    is_ddp_subprocess,
     wait_for_the_master,
     wait_for_the_master,
+    DDPNotSetupException,
 )
 )
 from super_gradients.training.utils.ema import ModelEMA
 from super_gradients.training.utils.ema import ModelEMA
 from super_gradients.training.utils.optimizer_utils import build_optimizer
 from super_gradients.training.utils.optimizer_utils import build_optimizer
@@ -81,6 +82,7 @@ from super_gradients.training.utils.callbacks import (
     ContextSgMethods,
     ContextSgMethods,
     LRCallbackBase,
     LRCallbackBase,
 )
 )
+from super_gradients.common.environment.device_utils import device_config
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg
 from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg
 from omegaconf import OmegaConf
 from omegaconf import OmegaConf
@@ -104,7 +106,7 @@ class Trainer:
         returns the test loss, accuracy and runtime
         returns the test loss, accuracy and runtime
     """
     """
 
 
-    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF, ckpt_root_dir: str = None):
+    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = None, ckpt_root_dir: str = None):
         """
         """
 
 
         :param experiment_name:                      Used for logging and loading purposes
         :param experiment_name:                      Used for logging and loading purposes
@@ -117,9 +119,20 @@ class Trainer:
                                                 pkg_resources.resource_filename('checkpoints', "") exists and will be used.
                                                 pkg_resources.resource_filename('checkpoints', "") exists and will be used.
 
 
         """
         """
+
+        # This should later be removed
+        if device is not None or multi_gpu is not None:
+            raise KeyError(
+                "Trainer does not accept anymore 'device' and 'multi_gpu' as argument. "
+                "Both should instead be passed to "
+                "super_gradients.setup_device(device=..., multi_gpu=..., num_gpus=...)"
+            )
+
+        if require_ddp_setup():
+            raise DDPNotSetupException()
+
         # SET THE EMPTY PROPERTIES
         # SET THE EMPTY PROPERTIES
         self.net, self.architecture, self.arch_params, self.dataset_interface = None, None, None, None
         self.net, self.architecture, self.arch_params, self.dataset_interface = None, None, None, None
-        self.device, self.multi_gpu = None, None
         self.ema = None
         self.ema = None
         self.ema_model = None
         self.ema_model = None
         self.sg_logger = None
         self.sg_logger = None
@@ -136,7 +149,8 @@ class Trainer:
         self.load_checkpoint = False
         self.load_checkpoint = False
         self.load_backbone = False
         self.load_backbone = False
         self.load_weights_only = False
         self.load_weights_only = False
-        self.ddp_silent_mode = False
+        self.ddp_silent_mode = is_ddp_subprocess()
+
         self.source_ckpt_folder_name = None
         self.source_ckpt_folder_name = None
         self.model_weight_averaging = None
         self.model_weight_averaging = None
         self.average_model_checkpoint_filename = "average_model.pth"
         self.average_model_checkpoint_filename = "average_model.pth"
@@ -166,9 +180,6 @@ class Trainer:
 
 
         self.checkpoints_dir_path = get_checkpoints_dir_path(experiment_name, ckpt_root_dir)
         self.checkpoints_dir_path = get_checkpoints_dir_path(experiment_name, ckpt_root_dir)
 
 
-        # INITIALIZE THE DEVICE FOR THE MODEL
-        self._initialize_device(requested_device=device, requested_multi_gpu=multi_gpu)
-
         # SET THE DEFAULTS
         # SET THE DEFAULTS
         # TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK
         # TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK
 
 
@@ -183,6 +194,10 @@ class Trainer:
         self.train_monitored_values = {}
         self.train_monitored_values = {}
         self.valid_monitored_values = {}
         self.valid_monitored_values = {}
 
 
+    @property
+    def device(self) -> str:
+        return device_config.device
+
     @classmethod
     @classmethod
     def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
     def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
         """
         """
@@ -192,14 +207,16 @@ class Trainer:
         @return: the model and the output of trainer.train(...) (i.e results tuple)
         @return: the model and the output of trainer.train(...) (i.e results tuple)
         """
         """
 
 
-        setup_device(multi_gpu=core_utils.get_param(cfg, "multi_gpu", MultiGPUMode.OFF), num_gpus=core_utils.get_param(cfg, "num_gpus"))
+        setup_device(
+            device=core_utils.get_param(cfg, "device"),
+            multi_gpu=core_utils.get_param(cfg, "multi_gpu"),
+            num_gpus=core_utils.get_param(cfg, "num_gpus"),
+        )
 
 
         # INSTANTIATE ALL OBJECTS IN CFG
         # INSTANTIATE ALL OBJECTS IN CFG
         cfg = hydra.utils.instantiate(cfg)
         cfg = hydra.utils.instantiate(cfg)
 
 
-        kwargs = parse_args(cfg, cls.__init__)
-
-        trainer = Trainer(**kwargs)
+        trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)
 
 
         # INSTANTIATE DATA LOADERS
         # INSTANTIATE DATA LOADERS
 
 
@@ -260,14 +277,16 @@ class Trainer:
         :param cfg: The parsed DictConfig from yaml recipe files or a dictionary
         :param cfg: The parsed DictConfig from yaml recipe files or a dictionary
         """
         """
 
 
-        setup_device(multi_gpu=core_utils.get_param(cfg, "multi_gpu", MultiGPUMode.OFF), num_gpus=core_utils.get_param(cfg, "num_gpus"))
+        setup_device(
+            device=core_utils.get_param(cfg, "device"),
+            multi_gpu=core_utils.get_param(cfg, "multi_gpu"),
+            num_gpus=core_utils.get_param(cfg, "num_gpus"),
+        )
 
 
         # INSTANTIATE ALL OBJECTS IN CFG
         # INSTANTIATE ALL OBJECTS IN CFG
         cfg = hydra.utils.instantiate(cfg)
         cfg = hydra.utils.instantiate(cfg)
 
 
-        kwargs = parse_args(cfg, cls.__init__)
-
-        trainer = Trainer(**kwargs)
+        trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)
 
 
         # INSTANTIATE DATA LOADERS
         # INSTANTIATE DATA LOADERS
         val_dataloader = dataloaders.get(
         val_dataloader = dataloaders.get(
@@ -336,21 +355,21 @@ class Trainer:
 
 
     def _net_to_device(self):
     def _net_to_device(self):
         """
         """
-        Manipulates self.net according to self.multi_gpu
+        Manipulates self.net according to device_config.multi_gpu
         """
         """
-        self.net.to(self.device)
+        self.net.to(device_config.device)
 
 
         # FOR MULTI-GPU TRAINING (not distributed)
         # FOR MULTI-GPU TRAINING (not distributed)
         sync_bn = core_utils.get_param(self.training_params, "sync_bn", default_val=False)
         sync_bn = core_utils.get_param(self.training_params, "sync_bn", default_val=False)
-        if self.multi_gpu == MultiGPUMode.DATA_PARALLEL:
-            self.net = torch.nn.DataParallel(self.net, device_ids=self.device_ids)
-        elif self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
+        if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
+            self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids())
+        elif device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
             if sync_bn:
             if sync_bn:
                 if not self.ddp_silent_mode:
                 if not self.ddp_silent_mode:
                     logger.info("DDP - Using Sync Batch Norm... Training time will be affected accordingly")
                     logger.info("DDP - Using Sync Batch Norm... Training time will be affected accordingly")
-                self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net).to(self.device)
+                self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net).to(device_config.device)
 
 
-            local_rank = int(self.device.split(":")[1])
+            local_rank = int(device_config.device.split(":")[1])
             self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
             self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
 
 
         else:
         else:
@@ -372,7 +391,7 @@ class Trainer:
         # RESET/INIT THE METRIC LOGGERS
         # RESET/INIT THE METRIC LOGGERS
         self._reset_metrics()
         self._reset_metrics()
 
 
-        self.train_metrics.to(self.device)
+        self.train_metrics.to(device_config.device)
         loss_avg_meter = core_utils.utils.AverageMeter()
         loss_avg_meter = core_utils.utils.AverageMeter()
 
 
         context = PhaseContext(
         context = PhaseContext(
@@ -381,7 +400,7 @@ class Trainer:
             metrics_compute_fn=self.train_metrics,
             metrics_compute_fn=self.train_metrics,
             loss_avg_meter=loss_avg_meter,
             loss_avg_meter=loss_avg_meter,
             criterion=self.criterion,
             criterion=self.criterion,
-            device=self.device,
+            device=device_config.device,
             lr_warmup_epochs=self.training_params.lr_warmup_epochs,
             lr_warmup_epochs=self.training_params.lr_warmup_epochs,
             sg_logger=self.sg_logger,
             sg_logger=self.sg_logger,
             train_loader=self.train_loader,
             train_loader=self.train_loader,
@@ -390,7 +409,7 @@ class Trainer:
         )
         )
 
 
         for batch_idx, batch_items in enumerate(progress_bar_train_loader):
         for batch_idx, batch_items in enumerate(progress_bar_train_loader):
-            batch_items = core_utils.tensor_container_to_device(batch_items, self.device, non_blocking=True)
+            batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True)
             inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
             inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
 
 
             if self.pre_prediction_callback is not None:
             if self.pre_prediction_callback is not None:
@@ -964,7 +983,7 @@ class Trainer:
             logger.warning("Train dataset size % batch_size != 0 and drop_last=False, this might result in smaller " "last batch.")
             logger.warning("Train dataset size % batch_size != 0 and drop_last=False, this might result in smaller " "last batch.")
         self._set_dataset_params()
         self._set_dataset_params()
 
 
-        if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
+        if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
             # Note: the dataloader uses sampler of the batch_sampler when it is not None.
             # Note: the dataloader uses sampler of the batch_sampler when it is not None.
             train_sampler = self.train_loader.batch_sampler.sampler if self.train_loader.batch_sampler is not None else self.train_loader.sampler
             train_sampler = self.train_loader.batch_sampler.sampler if self.train_loader.batch_sampler is not None else self.train_loader.sampler
             if isinstance(train_sampler, SequentialSampler):
             if isinstance(train_sampler, SequentialSampler):
@@ -984,7 +1003,7 @@ class Trainer:
         self._prep_net_for_train()
         self._prep_net_for_train()
 
 
         # SET RANDOM SEED
         # SET RANDOM SEED
-        random_seed(is_ddp=self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, device=self.device, seed=self.training_params.seed)
+        random_seed(is_ddp=device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, device=device_config.device, seed=self.training_params.seed)
 
 
         silent_mode = self.training_params.silent_mode or self.ddp_silent_mode
         silent_mode = self.training_params.silent_mode or self.ddp_silent_mode
         # METRICS
         # METRICS
@@ -1005,7 +1024,7 @@ class Trainer:
         elif isinstance(self.training_params.loss, nn.Module):
         elif isinstance(self.training_params.loss, nn.Module):
             self.criterion = self.training_params.loss
             self.criterion = self.training_params.loss
 
 
-        self.criterion.to(self.device)
+        self.criterion.to(device_config.device)
 
 
         self.max_epochs = self.training_params.max_epochs
         self.max_epochs = self.training_params.max_epochs
 
 
@@ -1032,7 +1051,7 @@ class Trainer:
         self.run_validation_freq = self.training_params.run_validation_freq
         self.run_validation_freq = self.training_params.run_validation_freq
         validation_results_tuple = (0, 0)
         validation_results_tuple = (0, 0)
         inf_time = 0
         inf_time = 0
-        timer = core_utils.Timer(self.device)
+        timer = core_utils.Timer(device_config.device)
 
 
         # IF THE LR MODE IS NOT DEFAULT TAKE IT FROM THE TRAINING PARAMS
         # IF THE LR MODE IS NOT DEFAULT TAKE IT FROM THE TRAINING PARAMS
         self.lr_mode = self.training_params.lr_mode
         self.lr_mode = self.training_params.lr_mode
@@ -1143,7 +1162,7 @@ class Trainer:
             architecture=self.architecture,
             architecture=self.architecture,
             arch_params=self.arch_params,
             arch_params=self.arch_params,
             metric_to_watch=self.metric_to_watch,
             metric_to_watch=self.metric_to_watch,
-            device=self.device,
+            device=device_config.device,
             context_methods=self._get_context_methods(Phase.PRE_TRAINING),
             context_methods=self._get_context_methods(Phase.PRE_TRAINING),
             ema_model=self.ema_model,
             ema_model=self.ema_model,
         )
         )
@@ -1153,7 +1172,7 @@ class Trainer:
         inputs, _, _ = sg_trainer_utils.unpack_batch_items(first_batch)
         inputs, _, _ = sg_trainer_utils.unpack_batch_items(first_batch)
 
 
         log_main_training_params(
         log_main_training_params(
-            multi_gpu=self.multi_gpu,
+            multi_gpu=device_config.multi_gpu,
             num_gpus=get_world_size(),
             num_gpus=get_world_size(),
             batch_size=len(inputs),
             batch_size=len(inputs),
             batch_accumulate=self.batch_accumulate,
             batch_accumulate=self.batch_accumulate,
@@ -1177,7 +1196,7 @@ class Trainer:
                 # IN DDP- SET_EPOCH WILL CAUSE EVERY PROCESS TO BE EXPOSED TO THE ENTIRE DATASET BY SHUFFLING WITH A
                 # IN DDP- SET_EPOCH WILL CAUSE EVERY PROCESS TO BE EXPOSED TO THE ENTIRE DATASET BY SHUFFLING WITH A
                 # DIFFERENT SEED EACH EPOCH START
                 # DIFFERENT SEED EACH EPOCH START
                 if (
                 if (
-                    self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
+                    device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
                     and hasattr(self.train_loader, "sampler")
                     and hasattr(self.train_loader, "sampler")
                     and hasattr(self.train_loader.sampler, "set_epoch")
                     and hasattr(self.train_loader.sampler, "set_epoch")
                 ):
                 ):
@@ -1195,11 +1214,14 @@ class Trainer:
                 # CALCULATE PRECISE BATCHNORM STATS
                 # CALCULATE PRECISE BATCHNORM STATS
                 if self.precise_bn:
                 if self.precise_bn:
                     compute_precise_bn_stats(
                     compute_precise_bn_stats(
-                        model=self.net, loader=self.train_loader, precise_bn_batch_size=self.precise_bn_batch_size, num_gpus=self.num_devices
+                        model=self.net, loader=self.train_loader, precise_bn_batch_size=self.precise_bn_batch_size, num_gpus=get_world_size()
                     )
                     )
                     if self.ema:
                     if self.ema:
                         compute_precise_bn_stats(
                         compute_precise_bn_stats(
-                            model=self.ema_model.ema, loader=self.train_loader, precise_bn_batch_size=self.precise_bn_batch_size, num_gpus=self.num_devices
+                            model=self.ema_model.ema,
+                            loader=self.train_loader,
+                            precise_bn_batch_size=self.precise_bn_batch_size,
+                            num_gpus=get_world_size(),
                         )
                         )
 
 
                 # model switch - we replace self.net.module with the ema model for the testing and saving part
                 # model switch - we replace self.net.module with the ema model for the testing and saving part
@@ -1241,7 +1263,7 @@ class Trainer:
             logger.info("For HARD Termination - Stop the process again")
             logger.info("For HARD Termination - Stop the process again")
 
 
         finally:
         finally:
-            if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
+            if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
                 # CLEAN UP THE MULTI-GPU PROCESS GROUP WHEN DONE
                 # CLEAN UP THE MULTI-GPU PROCESS GROUP WHEN DONE
                 if torch.distributed.is_initialized():
                 if torch.distributed.is_initialized():
                     torch.distributed.destroy_process_group()
                     torch.distributed.destroy_process_group()
@@ -1293,8 +1315,8 @@ class Trainer:
         self.scaler = GradScaler(enabled=mixed_precision_enabled)
         self.scaler = GradScaler(enabled=mixed_precision_enabled)
 
 
         if mixed_precision_enabled:
         if mixed_precision_enabled:
-            assert self.device.startswith("cuda"), "mixed precision is not available for CPU"
-            if self.multi_gpu == MultiGPUMode.DATA_PARALLEL:
+            assert device_config.device.startswith("cuda"), "mixed precision is not available for CPU"
+            if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
                 # IN DATAPARALLEL MODE WE NEED TO WRAP THE FORWARD FUNCTION OF OUR MODEL SO IT WILL RUN WITH AUTOCAST.
                 # IN DATAPARALLEL MODE WE NEED TO WRAP THE FORWARD FUNCTION OF OUR MODEL SO IT WILL RUN WITH AUTOCAST.
                 # BUT SINCE THE MODULE IS CLONED TO THE DEVICES ON EACH FORWARD CALL OF A DATAPARALLEL MODEL,
                 # BUT SINCE THE MODULE IS CLONED TO THE DEVICES ON EACH FORWARD CALL OF A DATAPARALLEL MODEL,
                 # WE HAVE TO REGISTER THE WRAPPER BEFORE EVERY FORWARD CALL
                 # WE HAVE TO REGISTER THE WRAPPER BEFORE EVERY FORWARD CALL
@@ -1386,11 +1408,11 @@ class Trainer:
         if hasattr(self.net, "structure"):
         if hasattr(self.net, "structure"):
             self.architecture = self.net.structure
             self.architecture = self.net.structure
 
 
-        self.net.to(self.device)
+        self.net.to(device_config.device)
 
 
-        if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
+        if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
             logger.warning("Warning: distributed training is not supported in re_build_model()")
             logger.warning("Warning: distributed training is not supported in re_build_model()")
-        self.net = torch.nn.DataParallel(self.net, device_ids=self.device_ids) if self.multi_gpu else core_utils.WrappedModel(self.net)
+        self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids()) if device_config.multi_gpu else core_utils.WrappedModel(self.net)
 
 
     @property
     @property
     def get_module(self):
     def get_module(self):
@@ -1399,93 +1421,9 @@ class Trainer:
     def set_module(self, module):
     def set_module(self, module):
         self.net = module
         self.net = module
 
 
-    @resolve_param("requested_multi_gpu", TypeFactory(MultiGPUMode.dict()))
-    def _initialize_device(self, requested_device: str, requested_multi_gpu: Union[MultiGPUMode, str]):
-        """
-        _initialize_device - Initializes the device for the model - Default is CUDA
-            :param requested_device:        Device to initialize ('cuda' / 'cpu')
-            :param requested_multi_gpu:     Get Multiple GPU
-        """
-
-        # SELECT CUDA DEVICE
-        if requested_device == "cuda":
-            if torch.cuda.is_available():
-                self.device = "cuda"  # TODO - we may want to set the device number as well i.e. 'cuda:1'
-            else:
-                raise RuntimeError("CUDA DEVICE NOT FOUND... EXITING")
-
-        if require_gpu_setup(requested_multi_gpu):
-            raise GPUModeNotSetupError()
-
-        # SELECT CPU DEVICE
-        elif requested_device == "cpu":
-            self.device = "cpu"
-            self.multi_gpu = False
-        else:
-            # SELECT CUDA DEVICE BY DEFAULT IF AVAILABLE
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        # DEFUALT IS SET TO 1 - IT IS CHANGED IF MULTI-GPU IS USED
-        self.num_devices = 1
-
-        # IN CASE OF MULTIPLE GPUS UPDATE THE LEARNING AND DATA PARAMETERS
-        # FIXME - CREATE A DISCUSSION ON THESE PARAMETERS - WE MIGHT WANT TO CHANGE THE WAY WE USE THE LR AND
-        if requested_multi_gpu != MultiGPUMode.OFF:
-            if "cuda" in self.device:
-                # COLLECT THE AVAILABLE GPU AND COUNT THE AVAILABLE GPUS AMOUNT
-                self.device_ids = list(range(torch.cuda.device_count()))
-                self.num_devices = len(self.device_ids)
-                if self.num_devices == 1:
-                    self.multi_gpu = MultiGPUMode.OFF
-                    if requested_multi_gpu != MultiGPUMode.AUTO:
-                        # if AUTO mode was set - do not log a warning
-                        logger.warning("\n[WARNING] - Tried running on multiple GPU but only a single GPU is available\n")
-                else:
-                    if requested_multi_gpu == MultiGPUMode.AUTO:
-                        if ddp_utils.is_distributed():
-                            requested_multi_gpu = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
-                        else:
-                            requested_multi_gpu = MultiGPUMode.DATA_PARALLEL
-
-                    self.multi_gpu = requested_multi_gpu
-                    if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
-                        self._initialize_ddp()
-            else:
-                # MULTIPLE GPUS CAN BE ACTIVE ONLY IF A GPU IS AVAILABLE
-                self.multi_gpu = MultiGPUMode.OFF
-                logger.warning("\n[WARNING] - Tried running on multiple GPU but none are available => running on CPU\n")
-
-    def _initialize_ddp(self):
-        """
-        Initialize Distributed Data Parallel
-
-        Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU.
-        Whatever learning rate and schedule you specify will be applied to the each GPU individually.
-        Since gradients are passed and summed (reduced) from all to all GPUs, the effective batch size is the
-        batch you specify times the number of GPUs. In the literature there are several "best practices" to set
-        learning rates and schedules for large batch sizes.
-        """
-        local_rank = ddp_utils.DDP_LOCAL_RANK
-        if local_rank > 0:
-            mute_current_process()
-
-        logger.info("Distributed training starting...")
-        if not torch.distributed.is_initialized():
-            backend = "gloo" if os.name == "nt" else "nccl"
-            torch.distributed.init_process_group(backend=backend, init_method="env://")
-
-        torch.cuda.set_device(local_rank)
-        self.device = "cuda:%d" % local_rank
-
-        # MAKE ALL HIGHER-RANK GPUS SILENT (DISTRIBUTED MODE)
-        self.ddp_silent_mode = local_rank > 0
-
-        if torch.distributed.get_rank() == 0:
-            logger.info(f"Training in distributed mode... with {str(torch.distributed.get_world_size())} GPUs")
-
     def _switch_device(self, new_device):
     def _switch_device(self, new_device):
-        self.device = new_device
-        self.net.to(self.device)
+        device_config.device = new_device
+        self.net.to(device_config.device)
 
 
     # FIXME - we need to resolve flake8's 'function is too complex' for this function
     # FIXME - we need to resolve flake8's 'function is too complex' for this function
     def _load_checkpoint_to_model(self):  # noqa: C901 - too complex
     def _load_checkpoint_to_model(self):  # noqa: C901 - too complex
@@ -1566,7 +1504,7 @@ class Trainer:
 
 
         # RESET METRIC RUNNERS
         # RESET METRIC RUNNERS
         self._reset_metrics()
         self._reset_metrics()
-        self.test_metrics.to(self.device)
+        self.test_metrics.to(device_config.device)
 
 
         if self.arch_params is None:
         if self.arch_params is None:
             self._init_arch_params()
             self._init_arch_params()
@@ -1629,8 +1567,8 @@ class Trainer:
         """
         """
         additional_log_items = {
         additional_log_items = {
             "initial_LR": self.training_params.initial_lr,
             "initial_LR": self.training_params.initial_lr,
-            "num_devices": self.num_devices,
-            "multi_gpu": str(self.multi_gpu),
+            "num_devices": get_world_size(),
+            "multi_gpu": str(device_config.multi_gpu),
             "device_type": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
             "device_type": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
         }
         }
         # ADD INSTALLED PACKAGE LIST + THEIR VERSIONS
         # ADD INSTALLED PACKAGE LIST + THEIR VERSIONS
@@ -1737,7 +1675,7 @@ class Trainer:
 
 
         self.net.eval()
         self.net.eval()
         self._reset_metrics()
         self._reset_metrics()
-        self.valid_metrics.to(self.device)
+        self.valid_metrics.to(device_config.device)
 
 
         return self.evaluate(
         return self.evaluate(
             data_loader=self.valid_loader, metrics=self.valid_metrics, evaluation_type=EvaluationType.VALIDATION, epoch=epoch, silent_mode=silent_mode
             data_loader=self.valid_loader, metrics=self.valid_metrics, evaluation_type=EvaluationType.VALIDATION, epoch=epoch, silent_mode=silent_mode
@@ -1778,7 +1716,7 @@ class Trainer:
             metrics_compute_fn=metrics,
             metrics_compute_fn=metrics,
             loss_avg_meter=loss_avg_meter,
             loss_avg_meter=loss_avg_meter,
             criterion=self.criterion,
             criterion=self.criterion,
-            device=self.device,
+            device=device_config.device,
             lr_warmup_epochs=lr_warmup_epochs,
             lr_warmup_epochs=lr_warmup_epochs,
             sg_logger=self.sg_logger,
             sg_logger=self.sg_logger,
             context_methods=self._get_context_methods(Phase.VALIDATION_BATCH_END),
             context_methods=self._get_context_methods(Phase.VALIDATION_BATCH_END),
@@ -1791,7 +1729,7 @@ class Trainer:
 
 
         with torch.no_grad():
         with torch.no_grad():
             for batch_idx, batch_items in enumerate(progress_bar_data_loader):
             for batch_idx, batch_items in enumerate(progress_bar_data_loader):
-                batch_items = core_utils.tensor_container_to_device(batch_items, self.device, non_blocking=True)
+                batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True)
                 inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
                 inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
 
 
                 output = self.net(inputs)
                 output = self.net(inputs)
@@ -1829,7 +1767,7 @@ class Trainer:
         #  DETECTIONMETRICS, WHICH ALREADY RETURN THE METRICS VALUEST HEMSELVES AND NOT THE ITEMS REQUIRED FOR SUCH
         #  DETECTIONMETRICS, WHICH ALREADY RETURN THE METRICS VALUEST HEMSELVES AND NOT THE ITEMS REQUIRED FOR SUCH
         #  COMPUTATION. ALSO REMOVE THE BELOW LINES BY IMPLEMENTING CRITERION AS A TORCHMETRIC.
         #  COMPUTATION. ALSO REMOVE THE BELOW LINES BY IMPLEMENTING CRITERION AS A TORCHMETRIC.
 
 
-        if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
+        if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
             logging_values = reduce_results_tuple_for_ddp(logging_values, next(self.net.parameters()).device)
             logging_values = reduce_results_tuple_for_ddp(logging_values, next(self.net.parameters()).device)
 
 
         pbar_message_dict = get_train_loop_description_dict(logging_values, metrics, self.loss_logging_items_names)
         pbar_message_dict = get_train_loop_description_dict(logging_values, metrics, self.loss_logging_items_names)
Discard
@@ -1,5 +1,7 @@
 import sys
 import sys
+import os
 import itertools
 import itertools
+from typing import List, Tuple
 from contextlib import contextmanager
 from contextlib import contextmanager
 
 
 import torch
 import torch
@@ -10,12 +12,18 @@ from torch.distributed.elastic.multiprocessing import Std
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
 
+from super_gradients.common.environment.ddp_utils import init_trainer
 from super_gradients.common.data_types.enum import MultiGPUMode
 from super_gradients.common.data_types.enum import MultiGPUMode
 from super_gradients.common.environment.argparse_utils import EXTRA_ARGS
 from super_gradients.common.environment.argparse_utils import EXTRA_ARGS
-from super_gradients.common.environment.ddp_utils import find_free_port, is_distributed
-from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.environment.ddp_utils import find_free_port, is_distributed, is_launched_using_sg
 
 
 
 
+from super_gradients.common.abstractions.abstract_logger import get_logger, mute_current_process
+from super_gradients.common.environment.device_utils import device_config
+
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories.type_factory import TypeFactory
+
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
@@ -142,6 +150,14 @@ def get_local_rank():
     return dist.get_rank() if dist.is_initialized() else 0
     return dist.get_rank() if dist.is_initialized() else 0
 
 
 
 
+def require_ddp_setup() -> bool:
+    return device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and device_config.assigned_rank != get_local_rank()
+
+
+def is_ddp_subprocess():
+    return torch.distributed.get_rank() > 0 if dist.is_initialized() else False
+
+
 def get_world_size() -> int:
 def get_world_size() -> int:
     """
     """
     Returns the world size if running in DDP, and 1 otherwise
     Returns the world size if running in DDP, and 1 otherwise
@@ -154,6 +170,14 @@ def get_world_size() -> int:
     return dist.get_world_size()
     return dist.get_world_size()
 
 
 
 
+def get_device_ids() -> List[int]:
+    return list(range(get_world_size()))
+
+
+def count_used_devices() -> int:
+    return len(get_device_ids())
+
+
 @contextmanager
 @contextmanager
 def wait_for_the_master(local_rank: int):
 def wait_for_the_master(local_rank: int):
     """
     """
@@ -171,33 +195,145 @@ def wait_for_the_master(local_rank: int):
             dist.barrier()
             dist.barrier()
 
 
 
 
-def setup_device(multi_gpu: MultiGPUMode = MultiGPUMode.OFF, num_gpus: int = None):
+def setup_gpu_mode(gpu_mode: MultiGPUMode = MultiGPUMode.OFF, num_gpus: int = None):
+    """[DEPRECATED in favor of setup_device] If required, launch ddp subprocesses.
+    :param gpu_mode:    DDP, DP, Off or AUTO
+    :param num_gpus:    Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
+    """
+    logger.warning("setup_gpu_mode is now deprecated in favor of setup_device")
+    setup_device(multi_gpu=gpu_mode, num_gpus=num_gpus)
+
+
+@resolve_param("multi_gpu", TypeFactory(MultiGPUMode.dict()))
+def setup_device(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None, device: str = "cuda"):
     """
     """
     If required, launch ddp subprocesses.
     If required, launch ddp subprocesses.
-    :param multi_gpu:   DDP, DP or Off
-    :param num_gpus:    Number of GPU's to use.
+    :param multi_gpu:    DDP, DP, Off or AUTO
+    :param num_gpus:     Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
+    """
+    init_trainer()
+
+    # When launching with torch.distributed.launch or torchrun, multi_gpu might not be set to DDP (since we are not using the recipe params)
+    # To avoid any issue we force multi_gpu to be DDP if the current process is ddp subprocess. We also set num_gpus, device to run smoothly.
+    if not is_launched_using_sg() and is_distributed():
+        multi_gpu, num_gpus, device = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, None, "cuda"
+
+    if device is None:
+        device = "cuda"
+
+    if device == "cuda" and not torch.cuda.is_available():
+        logger.warning("CUDA device is not available on your device... Moving to CPU.")
+        device = "cpu"
+
+    if device == "cpu":
+        setup_cpu(multi_gpu, num_gpus)
+    elif device == "cuda":
+        setup_gpu(multi_gpu, num_gpus)
+    else:
+        raise ValueError(f"Only valid values for device are: 'cpu' and 'cuda'. Received: '{device}'")
+
+
+def setup_cpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
     """
     """
-    if multi_gpu == MultiGPUMode.AUTO and torch.cuda.device_count() > 1:
-        multi_gpu = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
-    if require_gpu_setup(multi_gpu):
-        num_gpus = num_gpus or torch.cuda.device_count()
+    :param multi_gpu:    DDP, DP, Off or AUTO
+    :param num_gpus:     Number of GPU's to use.
+    """
+    if multi_gpu not in (MultiGPUMode.OFF, MultiGPUMode.AUTO):
+        raise ValueError(f"device='cpu' and multi_gpu={multi_gpu} are not compatible together.")
+
+    if num_gpus not in (0, None):
+        raise ValueError(f"device='cpu' and num_gpus={num_gpus} are not compatible together.")
+
+    device_config.device = "cpu"
+    device_config.multi_gpu = MultiGPUMode.OFF
+
+
+def setup_gpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
+    """
+    If required, launch ddp subprocesses.
+    :param multi_gpu:    DDP, DP, Off or AUTO
+    :param num_gpus:     Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
+    """
+
+    if num_gpus == 0:
+        raise ValueError("device='cuda' and num_gpus=0 are not compatible together.")
+
+    multi_gpu, num_gpus = _resolve_gpu_params(multi_gpu=multi_gpu, num_gpus=num_gpus)
+
+    device_config.device = "cuda"
+    device_config.multi_gpu = multi_gpu
+
+    if is_distributed():
+        initialize_ddp()
+    elif multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
+        restart_script_with_ddp(num_gpus=num_gpus)
+
+
+def _resolve_gpu_params(multi_gpu: MultiGPUMode, num_gpus: int) -> Tuple[MultiGPUMode, int]:
+    """
+    Resolve the values multi_gpu in (None, MultiGPUMode.AUTO) and num_gpus in (None, -1), and check compatibility between both parameters.
+    :param multi_gpu:    DDP, DP, Off or AUTO
+    :param num_gpus:     Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
+    """
+
+    # Resolve None
+    if multi_gpu is None:
+        if num_gpus is None:  # When Nothing is specified, just run on single GPU
+            multi_gpu = MultiGPUMode.OFF
+            num_gpus = 1
+        else:
+            multi_gpu = MultiGPUMode.AUTO
+
+    if num_gpus is None:
+        num_gpus = -1
+
+    # Resolve num_gpus
+    if num_gpus == -1:
+        if multi_gpu in (MultiGPUMode.OFF, MultiGPUMode.DATA_PARALLEL):
+            num_gpus = 1
+        elif multi_gpu in (MultiGPUMode.AUTO, MultiGPUMode.DISTRIBUTED_DATA_PARALLEL):
+            num_gpus = torch.cuda.device_count()
+
+    # Resolve multi_gpu
+    if multi_gpu == MultiGPUMode.AUTO:
+        if num_gpus > 1:
+            multi_gpu = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
+        else:
+            multi_gpu = MultiGPUMode.OFF
+
+    # Check compatibility between num_gpus and multi_gpu
+    if multi_gpu in (MultiGPUMode.OFF, MultiGPUMode.DATA_PARALLEL):
+        if num_gpus != 1:
+            raise ValueError(f"You specified num_gpus={num_gpus} but it must be 1 when working with multi_gpu={multi_gpu}")
+    else:
         if num_gpus > torch.cuda.device_count():
         if num_gpus > torch.cuda.device_count():
             raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPU's are available")
             raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPU's are available")
-        restart_script_with_ddp(num_gpus)
+    return multi_gpu, num_gpus
 
 
 
 
-def setup_gpu_mode(gpu_mode: MultiGPUMode = MultiGPUMode.OFF, num_gpus: int = None):
-    """If required, launch ddp subprocesses (deprecated).
-    :param gpu_mode:    DDP, DP or Off
-    :param num_gpus:    Number of GPU's to use.
+def initialize_ddp():
     """
     """
-    logger.warning("setup_gpu_mode is now deprecated in favor of setup_device. This will be removed in next version")
-    setup_device(multi_gpu=gpu_mode, num_gpus=num_gpus)
+    Initialize Distributed Data Parallel
+
+    Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU.
+    Whatever learning rate and schedule you specify will be applied to each GPU individually.
+    Since gradients are passed and summed (reduced) from all to all GPUs, the effective batch size is the
+    batch you specify times the number of GPUs. In the literature there are several "best practices" to set
+    learning rates and schedules for large batch sizes.
+    """
+
+    if device_config.assigned_rank > 0:
+        mute_current_process()
 
 
+    logger.info("Distributed training starting...")
+    if not torch.distributed.is_initialized():
+        backend = "gloo" if os.name == "nt" else "nccl"
+        torch.distributed.init_process_group(backend=backend, init_method="env://")
+    torch.cuda.set_device(device_config.assigned_rank)
 
 
-def require_gpu_setup(multi_gpu: MultiGPUMode) -> bool:
-    """Check if the environment requires a setup in order to work with DDP."""
-    return (multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL) and (not is_distributed())
+    if torch.distributed.get_rank() == 0:
+        logger.info(f"Training in distributed mode... with {str(torch.distributed.get_world_size())} GPUs")
+    device_config.device = "cuda:%d" % device_config.assigned_rank
 
 
 
 
 @record
 @record
@@ -209,7 +345,7 @@ def restart_script_with_ddp(num_gpus: int = None):
     ddp_port = find_free_port()
     ddp_port = find_free_port()
 
 
     # Get the value from recipe if specified, otherwise take all available devices.
     # Get the value from recipe if specified, otherwise take all available devices.
-    num_gpus = num_gpus if num_gpus else torch.cuda.device_count()
+    num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
     if num_gpus > torch.cuda.device_count():
     if num_gpus > torch.cuda.device_count():
         raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPU's are available")
         raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPU's are available")
 
 
@@ -253,3 +389,22 @@ def get_gpu_mem_utilization():
         return torch.cuda.memory_reserved()
         return torch.cuda.memory_reserved()
     else:
     else:
         return torch.cuda.memory_cached()
         return torch.cuda.memory_cached()
+
+
+class DDPNotSetupException(Exception):
+    """
+    Exception raised when DDP setup is required but was not done
+
+    Attributes:
+        message -- explanation of the error
+    """
+
+    def __init__(self):
+        self.message = (
+            "Your environment was not setup correctly for DDP.\n"
+            "Please run at the beginning of your script:\n"
+            ">>> from super_gradients.training.utils.distributed_training_utils import setup_device\n"
+            ">>> from super_gradients.common.data_types.enum import MultiGPUMode\n"
+            ">>> setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=...)"
+        )
+        super().__init__(self.message)
Discard
@@ -17,6 +17,7 @@ import torch
 
 
 from torch.utils.tensorboard import SummaryWriter
 from torch.utils.tensorboard import SummaryWriter
 
 
+from super_gradients.common.environment.device_utils import device_config
 from super_gradients.training.exceptions.dataset_exceptions import UnsupportedBatchItemsFormat
 from super_gradients.training.exceptions.dataset_exceptions import UnsupportedBatchItemsFormat
 from super_gradients.common.data_types.enum import MultiGPUMode
 from super_gradients.common.data_types.enum import MultiGPUMode
 
 
@@ -453,7 +454,7 @@ def log_main_training_params(multi_gpu: MultiGPUMode, num_gpus: int, batch_size:
     msg = (
     msg = (
         "TRAINING PARAMETERS:\n"
         "TRAINING PARAMETERS:\n"
         f"    - Mode:                         {multi_gpu.name if multi_gpu else 'Single GPU'}\n"
         f"    - Mode:                         {multi_gpu.name if multi_gpu else 'Single GPU'}\n"
-        f"    - Number of GPUs:               {num_gpus:<10} ({torch.cuda.device_count()} available on the machine)\n"
+        f"    - Number of GPUs:               {num_gpus if 'cuda' in device_config.device else 0:<10} ({torch.cuda.device_count()} available on the machine)\n"
         f"    - Dataset size:                 {len_train_set:<10} (len(train_set))\n"
         f"    - Dataset size:                 {len_train_set:<10} (len(train_set))\n"
         f"    - Batch size per GPU:           {batch_size:<10} (batch_size)\n"
         f"    - Batch size per GPU:           {batch_size:<10} (batch_size)\n"
         f"    - Batch Accumulate:             {batch_accumulate:<10} (batch_accumulate)\n"
         f"    - Batch Accumulate:             {batch_accumulate:<10} (batch_accumulate)\n"
Discard
@@ -1,4 +1,4 @@
-from super_gradients.training import MultiGPUMode, models
+from super_gradients.training import models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
@@ -21,7 +21,7 @@ class CallWrapper:
 
 
 class EMAIntegrationTest(unittest.TestCase):
 class EMAIntegrationTest(unittest.TestCase):
     def _init_model(self) -> None:
     def _init_model(self) -> None:
-        self.trainer = Trainer("resnet18_cifar_ema_test", device="cpu", multi_gpu=MultiGPUMode.OFF)
+        self.trainer = Trainer("resnet18_cifar_ema_test")
         self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
         self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
 
 
     @classmethod
     @classmethod
Discard
@@ -1,6 +1,5 @@
 import unittest
 import unittest
 
 
-from super_gradients.training import MultiGPUMode
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders import imagenet_val, imagenet_vit_base_val
 from super_gradients.training.dataloaders import imagenet_val, imagenet_vit_base_val
 from super_gradients.training.dataloaders.dataloaders import (
 from super_gradients.training.dataloaders.dataloaders import (
@@ -228,13 +227,13 @@ class PretrainedModelsTest(unittest.TestCase):
         }
         }
 
 
     def test_pretrained_resnet50_imagenet(self):
     def test_pretrained_resnet50_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_resnet50", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_resnet50")
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
 
 
     def test_transfer_learning_resnet50_imagenet(self):
     def test_transfer_learning_resnet50_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_resnet50_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_resnet50_transfer_learning")
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -244,14 +243,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_resnet34_imagenet(self):
     def test_pretrained_resnet34_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_resnet34", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_resnet34")
 
 
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
 
 
     def test_transfer_learning_resnet34_imagenet(self):
     def test_transfer_learning_resnet34_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_resnet34_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_resnet34_transfer_learning")
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -261,14 +260,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_resnet18_imagenet(self):
     def test_pretrained_resnet18_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_resnet18", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_resnet18")
 
 
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
 
 
     def test_transfer_learning_resnet18_imagenet(self):
     def test_transfer_learning_resnet18_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_resnet18_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_resnet18_transfer_learning")
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -278,14 +277,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_regnetY800", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_regnetY800")
 
 
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
 
 
     def test_transfer_learning_regnetY800_imagenet(self):
     def test_transfer_learning_regnetY800_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_regnetY800_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_regnetY800_transfer_learning")
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -295,14 +294,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_regnetY600_imagenet(self):
     def test_pretrained_regnetY600_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_regnetY600", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_regnetY600")
 
 
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
 
 
     def test_transfer_learning_regnetY600_imagenet(self):
     def test_transfer_learning_regnetY600_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_regnetY600_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_regnetY600_transfer_learning")
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -312,14 +311,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_regnetY400_imagenet(self):
     def test_pretrained_regnetY400_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_regnetY400", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_regnetY400")
 
 
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
 
 
     def test_transfer_learning_regnetY400_imagenet(self):
     def test_transfer_learning_regnetY400_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_regnetY400_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_regnetY400_transfer_learning")
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -329,14 +328,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_regnetY200_imagenet(self):
     def test_pretrained_regnetY200_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_regnetY200", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_regnetY200")
 
 
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
 
 
     def test_transfer_learning_regnetY200_imagenet(self):
     def test_transfer_learning_regnetY200_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_regnetY200_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_regnetY200_transfer_learning")
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -346,14 +345,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_repvgg_a0", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_repvgg_a0")
 
 
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
 
 
     def test_transfer_learning_repvgg_a0_imagenet(self):
     def test_transfer_learning_repvgg_a0_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_repvgg_a0_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_repvgg_a0_transfer_learning")
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -363,7 +362,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_regseg48_cityscapes(self):
     def test_pretrained_regseg48_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_regseg48", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_regseg48")
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -375,7 +374,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
 
 
     def test_transfer_learning_regseg48_cityscapes(self):
     def test_transfer_learning_regseg48_cityscapes(self):
-        trainer = Trainer("regseg48_cityscapes_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("regseg48_cityscapes_transfer_learning")
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"], **self.cityscapes_pretrained_ckpt_params)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -385,7 +384,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_ddrnet23_cityscapes(self):
     def test_pretrained_ddrnet23_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_ddrnet23", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_ddrnet23")
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -397,7 +396,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
 
 
     def test_pretrained_ddrnet23_slim_cityscapes(self):
     def test_pretrained_ddrnet23_slim_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_ddrnet23_slim", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_ddrnet23_slim")
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -409,7 +408,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
 
 
     def test_transfer_learning_ddrnet23_cityscapes(self):
     def test_transfer_learning_ddrnet23_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_ddrnet23_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_ddrnet23_transfer_learning")
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -419,7 +418,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_ddrnet23_slim_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_ddrnet23_slim_transfer_learning")
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"], **self.cityscapes_pretrained_ckpt_params)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -429,7 +428,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
     def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
-        trainer = Trainer("coco_segmentation_subclass_pretrained_shelfnet34_lw", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("coco_segmentation_subclass_pretrained_shelfnet34_lw")
         model = models.get(
         model = models.get(
             "shelfnet34_lw",
             "shelfnet34_lw",
             arch_params=self.coco_segmentation_subclass_pretrained_arch_params["shelfnet34_lw"],
             arch_params=self.coco_segmentation_subclass_pretrained_arch_params["shelfnet34_lw"],
@@ -439,14 +438,14 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)
 
 
     def test_pretrained_efficientnet_b0_imagenet(self):
     def test_pretrained_efficientnet_b0_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_efficientnet_b0", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_efficientnet_b0")
 
 
         model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
 
 
     def test_transfer_learning_efficientnet_b0_imagenet(self):
     def test_transfer_learning_efficientnet_b0_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_efficientnet_b0_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_efficientnet_b0_transfer_learning")
 
 
         model = models.get(
         model = models.get(
             "efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"], **self.imagenet_pretrained_ckpt_params, num_classes=5
             "efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"], **self.imagenet_pretrained_ckpt_params, num_classes=5
@@ -459,7 +458,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
-        trainer = Trainer("coco_ssd_lite_mobilenet_v2", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("coco_ssd_lite_mobilenet_v2")
         model = models.get("ssd_lite_mobilenet_v2", arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"], **self.coco_pretrained_ckpt_params)
         model = models.get("ssd_lite_mobilenet_v2", arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"], **self.coco_pretrained_ckpt_params)
         ssd_post_prediction_callback = SSDPostPredictCallback()
         ssd_post_prediction_callback = SSDPostPredictCallback()
         res = trainer.test(
         res = trainer.test(
@@ -471,7 +470,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["ssd_lite_mobilenet_v2"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["ssd_lite_mobilenet_v2"], delta=0.001)
 
 
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
-        trainer = Trainer("coco_ssd_lite_mobilenet_v2_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("coco_ssd_lite_mobilenet_v2_transfer_learning")
         transfer_arch_params = self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"].copy()
         transfer_arch_params = self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"].copy()
         transfer_arch_params["num_classes"] = 5
         transfer_arch_params["num_classes"] = 5
         model = models.get("ssd_lite_mobilenet_v2", arch_params=transfer_arch_params, **self.coco_pretrained_ckpt_params)
         model = models.get("ssd_lite_mobilenet_v2", arch_params=transfer_arch_params, **self.coco_pretrained_ckpt_params)
@@ -483,7 +482,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_ssd_mobilenet_v1_coco(self):
     def test_pretrained_ssd_mobilenet_v1_coco(self):
-        trainer = Trainer("coco_ssd_mobilenet_v1", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("coco_ssd_mobilenet_v1")
         model = models.get("ssd_mobilenet_v1", arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"], **self.coco_pretrained_ckpt_params)
         model = models.get("ssd_mobilenet_v1", arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"], **self.coco_pretrained_ckpt_params)
         ssd_post_prediction_callback = SSDPostPredictCallback()
         ssd_post_prediction_callback = SSDPostPredictCallback()
         res = trainer.test(
         res = trainer.test(
@@ -495,7 +494,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)
 
 
     def test_pretrained_yolox_s_coco(self):
     def test_pretrained_yolox_s_coco(self):
-        trainer = Trainer("yolox_s", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("yolox_s")
 
 
         model = models.get("yolox_s", **self.coco_pretrained_ckpt_params)
         model = models.get("yolox_s", **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
@@ -506,7 +505,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_s"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_s"], delta=0.001)
 
 
     def test_pretrained_yolox_m_coco(self):
     def test_pretrained_yolox_m_coco(self):
-        trainer = Trainer("yolox_m", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("yolox_m")
         model = models.get("yolox_m", **self.coco_pretrained_ckpt_params)
         model = models.get("yolox_m", **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
@@ -516,7 +515,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_m"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_m"], delta=0.001)
 
 
     def test_pretrained_yolox_l_coco(self):
     def test_pretrained_yolox_l_coco(self):
-        trainer = Trainer("yolox_l", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("yolox_l")
         model = models.get("yolox_l", **self.coco_pretrained_ckpt_params)
         model = models.get("yolox_l", **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
@@ -526,7 +525,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_l"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_l"], delta=0.001)
 
 
     def test_pretrained_yolox_n_coco(self):
     def test_pretrained_yolox_n_coco(self):
-        trainer = Trainer("yolox_n", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("yolox_n")
 
 
         model = models.get("yolox_n", **self.coco_pretrained_ckpt_params)
         model = models.get("yolox_n", **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
@@ -537,7 +536,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_n"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_n"], delta=0.001)
 
 
     def test_pretrained_yolox_t_coco(self):
     def test_pretrained_yolox_t_coco(self):
-        trainer = Trainer("yolox_t", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("yolox_t")
         model = models.get("yolox_t", **self.coco_pretrained_ckpt_params)
         model = models.get("yolox_t", **self.coco_pretrained_ckpt_params)
         res = trainer.test(
         res = trainer.test(
             model=model,
             model=model,
@@ -547,7 +546,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_t"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_t"], delta=0.001)
 
 
     def test_transfer_learning_yolox_n_coco(self):
     def test_transfer_learning_yolox_n_coco(self):
-        trainer = Trainer("test_transfer_learning_yolox_n_coco", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("test_transfer_learning_yolox_n_coco")
         model = models.get("yolox_n", **self.coco_pretrained_ckpt_params, num_classes=5)
         model = models.get("yolox_n", **self.coco_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -557,7 +556,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_transfer_learning_mobilenet_v3_large_imagenet(self):
     def test_transfer_learning_mobilenet_v3_large_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_mobilenet_v3_large_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_mobilenet_v3_large_transfer_learning")
 
 
         model = models.get(
         model = models.get(
             "mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
             "mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
@@ -570,14 +569,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_mobilenet_v3_large_imagenet(self):
     def test_pretrained_mobilenet_v3_large_imagenet(self):
-        trainer = Trainer("imagenet_mobilenet_v3_large", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_mobilenet_v3_large")
 
 
         model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_large"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_large"], delta=0.001)
 
 
     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_mobilenet_v3_small_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_mobilenet_v3_small_transfer_learning")
 
 
         model = models.get(
         model = models.get(
             "mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
             "mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5
@@ -590,14 +589,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_mobilenet_v3_small_imagenet(self):
     def test_pretrained_mobilenet_v3_small_imagenet(self):
-        trainer = Trainer("imagenet_mobilenet_v3_small", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_mobilenet_v3_small")
 
 
         model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_small"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v3_small"], delta=0.001)
 
 
     def test_transfer_learning_mobilenet_v2_imagenet(self):
     def test_transfer_learning_mobilenet_v2_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_mobilenet_v2_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_mobilenet_v2_transfer_learning")
 
 
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
@@ -608,14 +607,14 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_mobilenet_v2_imagenet(self):
     def test_pretrained_mobilenet_v2_imagenet(self):
-        trainer = Trainer("imagenet_mobilenet_v2", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_mobilenet_v2")
 
 
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"], **self.imagenet_pretrained_ckpt_params)
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         res = trainer.test(model=model, test_loader=self.imagenet_dataset, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0].cpu().item()
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
 
 
     def test_pretrained_stdc1_seg50_cityscapes(self):
     def test_pretrained_stdc1_seg50_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_stdc1_seg50", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_stdc1_seg50")
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -630,7 +629,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
 
 
     def test_transfer_learning_stdc1_seg50_cityscapes(self):
     def test_transfer_learning_stdc1_seg50_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_stdc1_seg50_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_stdc1_seg50_transfer_learning")
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -640,7 +639,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_stdc1_seg75_cityscapes(self):
     def test_pretrained_stdc1_seg75_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_stdc1_seg75", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_stdc1_seg75")
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -655,7 +654,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
 
 
     def test_transfer_learning_stdc1_seg75_cityscapes(self):
     def test_transfer_learning_stdc1_seg75_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_stdc1_seg75_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_stdc1_seg75_transfer_learning")
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -665,7 +664,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_stdc2_seg50_cityscapes(self):
     def test_pretrained_stdc2_seg50_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_stdc2_seg50", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_stdc2_seg50")
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -680,7 +679,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
 
 
     def test_transfer_learning_stdc2_seg50_cityscapes(self):
     def test_transfer_learning_stdc2_seg50_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_stdc2_seg50_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_stdc2_seg50_transfer_learning")
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -690,7 +689,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_stdc2_seg75_cityscapes(self):
     def test_pretrained_stdc2_seg75_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_stdc2_seg75", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_stdc2_seg75")
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(
             trainer.test(
@@ -705,7 +704,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
 
 
     def test_transfer_learning_stdc2_seg75_cityscapes(self):
     def test_transfer_learning_stdc2_seg75_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_stdc2_seg75_transfer_learning", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_stdc2_seg75_transfer_learning")
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"], **self.cityscapes_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model,
             model=model,
@@ -715,7 +714,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_transfer_learning_vit_base_imagenet21k(self):
     def test_transfer_learning_vit_base_imagenet21k(self):
-        trainer = Trainer("imagenet21k_pretrained_vit_base", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet21k_pretrained_vit_base")
 
 
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
@@ -726,7 +725,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_transfer_learning_vit_large_imagenet21k(self):
     def test_transfer_learning_vit_large_imagenet21k(self):
-        trainer = Trainer("imagenet21k_pretrained_vit_large", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet21k_pretrained_vit_large")
 
 
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet21k_pretrained_ckpt_params, num_classes=5)
         trainer.train(
         trainer.train(
@@ -737,7 +736,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_vit_base_imagenet(self):
     def test_pretrained_vit_base_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_vit_base", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_vit_base")
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
@@ -747,7 +746,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
 
 
     def test_pretrained_vit_large_imagenet(self):
     def test_pretrained_vit_large_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_vit_large", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_vit_large")
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
@@ -757,7 +756,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)
 
 
     def test_pretrained_beit_base_imagenet(self):
     def test_pretrained_beit_base_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_beit_base", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_beit_base")
         model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
         model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params)
         res = (
         res = (
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
             trainer.test(model=model, test_loader=self.imagenet_dataset_05_mean_std, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)[0]
@@ -767,7 +766,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["beit_base_patch16_224"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["beit_base_patch16_224"], delta=0.001)
 
 
     def test_transfer_learning_beit_base_imagenet(self):
     def test_transfer_learning_beit_base_imagenet(self):
-        trainer = Trainer("test_transfer_learning_beit_base_imagenet", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("test_transfer_learning_beit_base_imagenet")
 
 
         model = models.get(
         model = models.get(
             "beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params, num_classes=5
             "beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"], **self.imagenet_pretrained_ckpt_params, num_classes=5
@@ -780,7 +779,7 @@ class PretrainedModelsTest(unittest.TestCase):
         )
         )
 
 
     def test_pretrained_pplite_t_seg50_cityscapes(self):
     def test_pretrained_pplite_t_seg50_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_pplite_t_seg50", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_pplite_t_seg50")
         model = models.get("pp_lite_t_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("pp_lite_t_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
@@ -796,7 +795,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg50"], delta=0.001)
 
 
     def test_pretrained_pplite_t_seg75_cityscapes(self):
     def test_pretrained_pplite_t_seg75_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_pplite_t_seg75", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_pplite_t_seg75")
         model = models.get("pp_lite_t_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("pp_lite_t_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
@@ -812,7 +811,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg75"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg75"], delta=0.001)
 
 
     def test_pretrained_pplite_b_seg50_cityscapes(self):
     def test_pretrained_pplite_b_seg50_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_pplite_b_seg50", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_pplite_b_seg50")
         model = models.get("pp_lite_b_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("pp_lite_b_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
@@ -828,7 +827,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg50"], delta=0.001)
 
 
     def test_pretrained_pplite_b_seg75_cityscapes(self):
     def test_pretrained_pplite_b_seg75_cityscapes(self):
-        trainer = Trainer("cityscapes_pretrained_pplite_b_seg75", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("cityscapes_pretrained_pplite_b_seg75")
         model = models.get("pp_lite_b_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
         model = models.get("pp_lite_b_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"], **self.cityscapes_pretrained_ckpt_params)
 
 
         res = (
         res = (
Discard
@@ -13,7 +13,7 @@ from super_gradients.training.losses.kd_losses import KDLogitsLoss
 class KDEMATest(unittest.TestCase):
 class KDEMATest(unittest.TestCase):
     @classmethod
     @classmethod
     def setUp(cls):
     def setUp(cls):
-        cls.sg_trained_teacher = Trainer("sg_trained_teacher", device="cpu")
+        cls.sg_trained_teacher = Trainer("sg_trained_teacher")
 
 
         cls.kd_train_params = {
         cls.kd_train_params = {
             "max_epochs": 3,
             "max_epochs": 3,
@@ -38,7 +38,7 @@ class KDEMATest(unittest.TestCase):
     def test_teacher_ema_not_duplicated(self):
     def test_teacher_ema_not_duplicated(self):
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
         """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
 
 
-        kd_model = KDTrainer("test_teacher_ema_not_duplicated", device="cpu")
+        kd_model = KDTrainer("test_teacher_ema_not_duplicated")
         student = models.get("resnet18", arch_params={"num_classes": 1000})
         student = models.get("resnet18", arch_params={"num_classes": 1000})
         teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
         teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
@@ -58,7 +58,7 @@ class KDEMATest(unittest.TestCase):
 
 
         # Create a KD trainer and train it
         # Create a KD trainer and train it
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
-        kd_model = KDTrainer("test_kd_ema_ckpt_reload", device="cpu")
+        kd_model = KDTrainer("test_kd_ema_ckpt_reload")
         student = models.get("resnet18", arch_params={"num_classes": 1000})
         student = models.get("resnet18", arch_params={"num_classes": 1000})
         teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
         teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
@@ -73,7 +73,7 @@ class KDEMATest(unittest.TestCase):
         net = kd_model.net
         net = kd_model.net
 
 
         # Load the trained KD trainer
         # Load the trained KD trainer
-        kd_model = KDTrainer("test_kd_ema_ckpt_reload", device="cpu")
+        kd_model = KDTrainer("test_kd_ema_ckpt_reload")
         student = models.get("resnet18", arch_params={"num_classes": 1000})
         student = models.get("resnet18", arch_params={"num_classes": 1000})
         teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
         teacher = models.get("resnet50", arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
 
 
Discard
@@ -67,7 +67,7 @@ class KDTrainerTest(unittest.TestCase):
         self.assertTrue(initial_param_groups[0]["lr"] == 0.2 == updated_param_groups[0]["lr"])
         self.assertTrue(initial_param_groups[0]["lr"] == 0.2 == updated_param_groups[0]["lr"])
 
 
     def test_train_kd_module_external_models(self):
     def test_train_kd_module_external_models(self):
-        sg_model = KDTrainer("test_train_kd_module_external_models", device="cpu")
+        sg_model = KDTrainer("test_train_kd_module_external_models")
         teacher_model = ResNet50(arch_params={}, num_classes=5)
         teacher_model = ResNet50(arch_params={}, num_classes=5)
         student_model = ResNet18(arch_params={}, num_classes=5)
         student_model = ResNet18(arch_params={}, num_classes=5)
 
 
@@ -86,7 +86,7 @@ class KDTrainerTest(unittest.TestCase):
         self.assertFalse(check_models_have_same_weights(student_model, sg_model.net.module.student))
         self.assertFalse(check_models_have_same_weights(student_model, sg_model.net.module.student))
 
 
     def test_train_model_with_input_adapter(self):
     def test_train_model_with_input_adapter(self):
-        kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter", device="cpu")
+        kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
         student = models.get("resnet18", arch_params={"num_classes": 5})
         student = models.get("resnet18", arch_params={"num_classes": 5})
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
 
 
@@ -107,7 +107,7 @@ class KDTrainerTest(unittest.TestCase):
         self.assertEqual(kd_trainer.net.module.teacher_input_adapter, adapter)
         self.assertEqual(kd_trainer.net.module.teacher_input_adapter, adapter)
 
 
     def test_load_ckpt_best_for_student(self):
     def test_load_ckpt_best_for_student(self):
-        kd_trainer = KDTrainer("test_load_ckpt_best", device="cpu")
+        kd_trainer = KDTrainer("test_load_ckpt_best")
         student = models.get("resnet18", arch_params={"num_classes": 5})
         student = models.get("resnet18", arch_params={"num_classes": 5})
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
@@ -126,7 +126,7 @@ class KDTrainerTest(unittest.TestCase):
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
 
 
     def test_load_ckpt_best_for_student_with_ema(self):
     def test_load_ckpt_best_for_student_with_ema(self):
-        kd_trainer = KDTrainer("test_load_ckpt_best", device="cpu")
+        kd_trainer = KDTrainer("test_load_ckpt_best")
         student = models.get("resnet18", arch_params={"num_classes": 5})
         student = models.get("resnet18", arch_params={"num_classes": 5})
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
@@ -146,7 +146,7 @@ class KDTrainerTest(unittest.TestCase):
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
         self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
 
 
     def test_resume_kd_training(self):
     def test_resume_kd_training(self):
-        kd_trainer = KDTrainer("test_resume_training_start", device="cpu")
+        kd_trainer = KDTrainer("test_resume_training_start")
         student = models.get("resnet18", arch_params={"num_classes": 5})
         student = models.get("resnet18", arch_params={"num_classes": 5})
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
@@ -160,7 +160,7 @@ class KDTrainerTest(unittest.TestCase):
         )
         )
         latest_net = deepcopy(kd_trainer.net)
         latest_net = deepcopy(kd_trainer.net)
 
 
-        kd_trainer = KDTrainer("test_resume_training_start", device="cpu")
+        kd_trainer = KDTrainer("test_resume_training_start")
         student = models.get("resnet18", arch_params={"num_classes": 5})
         student = models.get("resnet18", arch_params={"num_classes": 5})
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
         teacher = models.get("resnet50", arch_params={"num_classes": 5}, pretrained_weights="imagenet")
 
 
Discard
@@ -1,6 +1,6 @@
 import unittest
 import unittest
 import super_gradients
 import super_gradients
-from super_gradients.training import MultiGPUMode, models
+from super_gradients.training import models
 from super_gradients.training import Trainer
 from super_gradients.training import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
@@ -14,17 +14,17 @@ class PretrainedModelsUnitTest(unittest.TestCase):
         self.imagenet_pretrained_models = ["resnet50", "repvgg_a0", "regnetY800"]
         self.imagenet_pretrained_models = ["resnet50", "repvgg_a0", "regnetY800"]
 
 
     def test_pretrained_resnet50_imagenet(self):
     def test_pretrained_resnet50_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_resnet50_unit_test", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_resnet50_unit_test")
         model = models.get("resnet50", pretrained_weights="imagenet")
         model = models.get("resnet50", pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_regnetY800_unit_test", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_regnetY800_unit_test")
         model = models.get("regnetY800", pretrained_weights="imagenet")
         model = models.get("regnetY800", pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
-        trainer = Trainer("imagenet_pretrained_repvgg_a0_unit_test", multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer("imagenet_pretrained_repvgg_a0_unit_test")
         model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
 
Discard
@@ -4,7 +4,7 @@ import os
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader, detection_test_dataloader, segmentation_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader, detection_test_dataloader, segmentation_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
-from super_gradients.training import MultiGPUMode, models
+from super_gradients.training import models
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
@@ -31,7 +31,7 @@ class TestWithoutTrainTest(unittest.TestCase):
 
 
     @staticmethod
     @staticmethod
     def get_detection_trainer(name=""):
     def get_detection_trainer(name=""):
-        trainer = Trainer(name, multi_gpu=MultiGPUMode.OFF)
+        trainer = Trainer(name)
         model = models.get("yolox_s", num_classes=5)
         model = models.get("yolox_s", num_classes=5)
         return trainer, model
         return trainer, model
 
 
Discard
@@ -30,7 +30,7 @@ class TestViT(unittest.TestCase):
         """
         """
         Validate vit_base
         Validate vit_base
         """
         """
-        trainer = Trainer("test_vit_base", device="cpu")
+        trainer = Trainer("test_vit_base")
         model = models.get("vit_base", arch_params=self.arch_params, num_classes=5)
         model = models.get("vit_base", arch_params=self.arch_params, num_classes=5)
         trainer.train(
         trainer.train(
             model=model, training_params=self.train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
             model=model, training_params=self.train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
Discard