Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#396 Trainer constructor cleanup

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
49 changed files with 176 additions and 246 deletions
  1. 0
    1
      src/super_gradients/examples/resnet_qat/resnet_qat_example.py
  2. 1
    1
      src/super_gradients/recipes/cifar10_resnet.yaml
  3. 1
    1
      src/super_gradients/recipes/cityscapes_ddrnet.yaml
  4. 1
    1
      src/super_gradients/recipes/cityscapes_regseg48.yaml
  5. 1
    1
      src/super_gradients/recipes/cityscapes_stdc_base.yaml
  6. 1
    1
      src/super_gradients/recipes/coco2017_ssd_lite_mobilenet_v2.yaml
  7. 1
    1
      src/super_gradients/recipes/coco2017_yolox.yaml
  8. 1
    1
      src/super_gradients/recipes/coco_segmentation_shelfnet_lw.yaml
  9. 1
    1
      src/super_gradients/recipes/imagenet_efficientnet.yaml
  10. 1
    1
      src/super_gradients/recipes/imagenet_mobilenetv2.yaml
  11. 1
    1
      src/super_gradients/recipes/imagenet_mobilenetv3_base.yaml
  12. 1
    1
      src/super_gradients/recipes/imagenet_regnetY.yaml
  13. 1
    1
      src/super_gradients/recipes/imagenet_repvgg.yaml
  14. 1
    1
      src/super_gradients/recipes/imagenet_resnet50.yaml
  15. 1
    1
      src/super_gradients/recipes/imagenet_resnet50_kd.yaml
  16. 1
    1
      src/super_gradients/recipes/imagenet_vit_base.yaml
  17. 2
    0
      src/super_gradients/recipes/training_hyperparams/default_train_params.yaml
  18. 3
    11
      src/super_gradients/training/kd_trainer/kd_trainer.py
  19. 1
    0
      src/super_gradients/training/params.py
  20. 40
    57
      src/super_gradients/training/sg_trainer/sg_trainer.py
  21. 2
    22
      src/super_gradients/training/utils/checkpoint_utils.py
  22. 0
    2
      src/super_gradients/training/utils/weight_averaging_utils.py
  23. 2
    2
      tests/end_to_end_tests/cifar_trainer_test.py
  24. 1
    1
      tests/end_to_end_tests/trainer_test.py
  25. 2
    2
      tests/integration_tests/conversion_callback_test.py
  26. 1
    1
      tests/integration_tests/deci_lab_export_test.py
  27. 1
    1
      tests/integration_tests/ema_train_integration_test.py
  28. 1
    1
      tests/integration_tests/lr_test.py
  29. 59
    60
      tests/integration_tests/pretrained_models_test.py
  30. 1
    1
      tests/integration_tests/qat_integration_test.py
  31. 1
    3
      tests/unit_tests/dataset_statistics_test.py
  32. 3
    4
      tests/unit_tests/detection_utils_test.py
  33. 7
    7
      tests/unit_tests/early_stop_test.py
  34. 1
    1
      tests/unit_tests/factories_test.py
  35. 1
    21
      tests/unit_tests/initialize_with_dataloaders_test.py
  36. 2
    2
      tests/unit_tests/load_ema_ckpt_test.py
  37. 1
    1
      tests/unit_tests/lr_cooldown_test.py
  38. 4
    4
      tests/unit_tests/lr_warmup_test.py
  39. 1
    1
      tests/unit_tests/phase_context_test.py
  40. 1
    1
      tests/unit_tests/phase_delegates_test.py
  41. 3
    3
      tests/unit_tests/pretrained_models_unit_test.py
  42. 1
    1
      tests/unit_tests/save_ckpt_test.py
  43. 1
    1
      tests/unit_tests/strictload_enum_test.py
  44. 6
    7
      tests/unit_tests/test_without_train_test.py
  45. 1
    1
      tests/unit_tests/train_logging_test.py
  46. 7
    7
      tests/unit_tests/train_with_intialized_param_args_test.py
  47. 2
    2
      tests/unit_tests/train_with_precise_bn_test.py
  48. 1
    1
      tests/unit_tests/update_param_groups_unit_test.py
  49. 1
    1
      tutorials/what_are_recipes_and_how_to_use.ipynb
@@ -25,7 +25,6 @@ from super_gradients.training.utils.quantization_utils import PostQATConversionC
 super_gradients.init_trainer()
 
 trainer = Trainer("resnet18_qat_example",
-                  model_checkpoints_location='local',
                   multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
 
 train_loader = dataloaders.imagenet_train()
Discard
@@ -22,7 +22,7 @@ training_hyperparams:
   resume: ${resume}
 
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 architecture: resnet18_cifar
Discard
@@ -78,7 +78,7 @@ checkpoint_params:
 
 experiment_name: ${architecture}_cityscapes
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 multi_gpu: DDP
Discard
@@ -44,7 +44,7 @@ arch_params:
   strict_load: no_key_matching
 
 load_checkpoint: False
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 resume: False
Discard
@@ -24,7 +24,7 @@ checkpoint_params:
 architecture: stdc1_seg
 experiment_name: ${architecture}_cityscapes
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 multi_gpu: DDP
Discard
@@ -34,7 +34,7 @@ val_dataloader: coco2017_val
 architecture: ssd_lite_mobilenet_v2
 
 data_loader_num_workers: 8
-model_checkpoints_location: local
+
 experiment_suffix: res${dataset_params.train_image_size}
 experiment_name: ${architecture}_coco_${experiment_suffix}
 
Discard
@@ -40,7 +40,7 @@ defaults:
 train_dataloader: coco2017_train
 val_dataloader: coco2017_val
 
-model_checkpoints_location: local
+
 
 load_checkpoint: False
 resume: False
Discard
@@ -9,7 +9,7 @@
 #   0. Make sure that the data is stored in dataset_params.dataset_dir or add "dataset_params.data_dir=<PATH-TO-DATASET>" at the end of the command below (feel free to check ReadMe)
 #   1. Move to the project root (where you will find the ReadMe and src folder)
 #   2. Run the command:
-#       python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco_segmentation_shelfnet_lw --model_checkpoints_location=<checkpoint-location>
+#       python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco_segmentation_shelfnet_lw
 
 
 # /!\ THIS RECIPE IS NOT SUPPORTED AT THE MOMENT /!\
Discard
@@ -27,7 +27,7 @@ training_hyperparams:
 
 experiment_name: efficientnet_b0_imagenet
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 multi_gpu: DDP
Discard
@@ -25,7 +25,7 @@ arch_params:
 
 data_loader_num_workers: 8
 
-model_checkpoints_location: local
+
 resume: False
 training_hyperparams:
   resume: ${resume}
Discard
@@ -8,7 +8,7 @@ defaults:
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
-model_checkpoints_location: local
+
 resume: False
 training_hyperparams:
   resume: ${resume}
Discard
@@ -36,7 +36,7 @@ arch_params:
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
-model_checkpoints_location: local
+
 load_checkpoint: False
 resume: False
 training_hyperparams:
Discard
@@ -25,7 +25,7 @@ train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
 
-model_checkpoints_location: local
+
 resume: False
 training_hyperparams:
   resume: ${resume}
Discard
@@ -24,7 +24,7 @@ arch_params:
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
-model_checkpoints_location: local
+
 resume: False
 training_hyperparams:
   resume: ${resume}
Discard
@@ -66,7 +66,7 @@ student_checkpoint_params:
   pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenent").
 
 
-model_checkpoints_location: local
+
 
 
 run_teacher_on_eval: True
Discard
@@ -21,7 +21,7 @@ defaults:
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
-model_checkpoints_location: local
+
 
 resume: False
 training_hyperparams:
Discard
@@ -1,4 +1,6 @@
 resume: False # whether to continue training from ckpt with the same experiment name.
+resume_path: # Explicit checkpoint path (.pth file) to use to resume training.
+ckpt_name: ckpt_latest.pth  # The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and resume_path=None
 lr_mode: # Learning rate scheduling policy, one of ['step','poly','cosine','function']
 lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
 lr_warmup_epochs: 0 # number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
Discard
@@ -9,7 +9,7 @@ from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.sg_trainer import Trainer
-from typing import Union, List, Any
+from typing import Union
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
@@ -20,7 +20,6 @@ from super_gradients.training.exceptions.kd_trainer_exceptions import Architectu
     UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
     TeacherKnowledgeException, UndefinedNumClassesException
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
-from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.utils.ema import KDModelEMA
 from super_gradients.training.utils.sg_trainer_utils import parse_args
 
@@ -29,15 +28,8 @@ logger = get_logger(__name__)
 
 class KDTrainer(Trainer):
     def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 model_checkpoints_location: str = 'local', overwrite_local_checkpoint: bool = True,
-                 ckpt_name: str = 'ckpt_latest.pth',
-                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
-                 train_loader: DataLoader = None,
-                 valid_loader: DataLoader = None, test_loader: DataLoader = None, classes: List[Any] = None):
-
-        super().__init__(experiment_name, device, multi_gpu, model_checkpoints_location, overwrite_local_checkpoint,
-                         ckpt_name, post_prediction_callback,
-                         ckpt_root_dir, train_loader, valid_loader, test_loader, classes)
+                 ckpt_root_dir: str = None):
+        super().__init__(experiment_name=experiment_name, device=device, multi_gpu=multi_gpu, ckpt_root_dir=ckpt_root_dir)
         self.student_architecture = None
         self.teacher_architecture = None
         self.student_arch_params = None
Discard
@@ -63,6 +63,7 @@ DEFAULT_TRAINING_PARAMS = {"lr_warmup_epochs": 0,
                            },
                            "resume": False,
                            "resume_path": None,
+                           "ckpt_name": 'ckpt_latest.pth',
                            "resume_strict_load": False
                            }
 
Discard
@@ -2,14 +2,14 @@ import inspect
 import os
 import sys
 from copy import deepcopy
-from typing import Union, Tuple, Mapping, List, Any
+from typing import Union, Tuple, Mapping
 
 import hydra
 import numpy as np
 import torch
 from omegaconf import DictConfig
 from torch import nn
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
 from torch.cuda.amp import GradScaler, autocast
 from torchmetrics import MetricCollection
 from tqdm import tqdm
@@ -33,14 +33,12 @@ from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.utils import sg_trainer_utils
 from super_gradients.training.utils.quantization_utils import QATCallback
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args
-from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, \
-    IllegalDataloaderInitialization, GPUModeNotSetupError
+from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, GPUModeNotSetupError
 from super_gradients.training.losses import LOSSES
 from super_gradients.training.metrics.metric_utils import get_metrics_titles, get_metrics_results_tuple, \
     get_logging_values, \
     get_metrics_dict, get_train_loop_description_dict
 from super_gradients.training.params import TrainingParams
-from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.utils.distributed_training_utils import MultiGPUModeAutocastWrapper, \
     reduce_results_tuple_for_ddp, compute_precise_bn_stats, setup_gpu_mode, require_gpu_setup
 from super_gradients.training.utils.ema import ModelEMA
@@ -76,30 +74,19 @@ class Trainer:
         returns the test loss, accuracy and runtime
     """
 
-    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 model_checkpoints_location: str = 'local',
-                 overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth',
-                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
-                 train_loader: DataLoader = None, valid_loader: DataLoader = None, test_loader: DataLoader = None,
-                 classes: List[Any] = None):
+    def __init__(self, experiment_name: str, device: str = None,
+                 multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
+                 ckpt_root_dir: str = None):
         """
 
         :param experiment_name:                      Used for logging and loading purposes
         :param device:                          If equal to 'cpu' runs on the CPU otherwise on GPU
         :param multi_gpu:                       If True, runs on all available devices
-        :param model_checkpoints_location:      If set to 's3' saves the Checkpoints in AWS S3
                                                 otherwise saves the Checkpoints Locally
-        :param overwrite_local_checkpoint:      If set to False keeps the current local checkpoint when importing
                                                 checkpoint from cloud service, otherwise overwrites the local checkpoints file
-        :param ckpt_name:                       The Checkpoint to Load
         :param ckpt_root_dir:                   Local root directory path where all experiment logging directories will
                                                 reside. When none is give, it is assumed that
                                                 pkg_resources.resource_filename('checkpoints', "") exists and will be used.
-        :param train_loader:                    Training set Dataloader instead of using DatasetInterface, must pass "valid_loader"
-                                                and "classes" along with it
-        :param valid_loader:                    Validation set Dataloader
-        :param test_loader:                     Test set Dataloader
-        :param classes:                         List of class labels
 
         """
         # SET THE EMPTY PROPERTIES
@@ -109,7 +96,6 @@ class Trainer:
         self.ema_model = None
         self.sg_logger = None
         self.update_param_groups = None
-        self.post_prediction_callback = None
         self.criterion = None
         self.training_params = None
         self.scaler = None
@@ -144,10 +130,7 @@ class Trainer:
 
         # SETTING THE PROPERTIES FROM THE CONSTRUCTOR
         self.experiment_name = experiment_name
-        self.ckpt_name = ckpt_name
-        self.overwrite_local_checkpoint = overwrite_local_checkpoint
-        self.model_checkpoints_location = model_checkpoints_location
-        self._set_dataset_properties(classes, test_loader, train_loader, valid_loader)
+        self.ckpt_name = None
 
         # CREATING THE LOGGING DIR BASED ON THE INPUT PARAMS TO PREVENT OVERWRITE OF LOCAL VERSION
         if ckpt_root_dir:
@@ -161,7 +144,6 @@ class Trainer:
         # INITIALIZE THE DEVICE FOR THE MODEL
         self._initialize_device(requested_device=device, requested_multi_gpu=multi_gpu)
 
-        self.post_prediction_callback = post_prediction_callback
         # SET THE DEFAULTS
         # TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK
 
@@ -221,25 +203,18 @@ class Trainer:
                       valid_loader=val_dataloader,
                       training_params=cfg.training_hyperparams)
 
-    def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
-        if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
-            raise IllegalDataloaderInitialization()
-
-        dataset_params = {"batch_size": train_loader.batch_size if train_loader else None,
-                          "val_batch_size": valid_loader.batch_size if valid_loader else None,
-                          "test_batch_size": test_loader.batch_size if test_loader else None,
-                          "dataset_dir": None,
-                          "s3_link": None}
-
-        if train_loader and self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
-            if not all([isinstance(train_loader.sampler, DistributedSampler),
-                        isinstance(valid_loader.sampler, DistributedSampler),
-                        test_loader is None or isinstance(test_loader.sampler, DistributedSampler)]):
-                logger.warning(
-                    "DDP training was selected but the dataloader samplers are not of type DistributedSamplers")
-
-        self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
-            HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
+    def _set_dataset_params(self):
+        self.dataset_params = {
+            "train_dataset_params": self.train_loader.dataset.dataset_params if hasattr(self.train_loader.dataset,
+                                                                                        "dataset_params") else None,
+            "train_dataloader_params": self.train_loader.dataloader_params if hasattr(self.train_loader,
+                                                                                      "dataloader_params") else None,
+            "valid_dataset_params": self.valid_loader.dataset.dataset_params if hasattr(self.valid_loader.dataset,
+                                                                                        "dataset_params") else None,
+            "valid_dataloader_params": self.valid_loader.dataloader_params if hasattr(self.valid_loader,
+                                                                                      "dataloader_params") else None
+        }
+        self.dataset_params = HpmStruct(**self.dataset_params)
 
     def _set_ckpt_loading_attributes(self):
         """
@@ -408,8 +383,7 @@ class Trainer:
                                                                source_ckpt_folder_name=self.source_ckpt_folder_name,
                                                                metric_to_watch=self.metric_to_watch,
                                                                metric_idx=self.metric_idx_in_results_tuple,
-                                                               load_checkpoint=self.load_checkpoint,
-                                                               model_checkpoints_location=self.model_checkpoints_location)
+                                                               load_checkpoint=self.load_checkpoint)
 
     def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context: PhaseContext, *args, **kwargs):
         """
@@ -520,6 +494,7 @@ class Trainer:
         self.load_ema_as_net = False
         self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
         self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
+        self.ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", 'ckpt_latest.pth')
         self._load_checkpoint_to_model()
 
     def _init_arch_params(self):
@@ -546,6 +521,21 @@ class Trainer:
             :param train_loader: Dataloader for train set.
             :param valid_loader: Dataloader for validation.
             :param training_params:
+
+                - `resume` : bool (default=False)
+
+                    Whether to continue training from ckpt with the same experiment name
+                     (i.e resume from CKPT_ROOT_DIR/EXPERIMENT_NAME/CKPT_NAME)
+
+                - `ckpt_name` : str (default=ckpt_latest.pth)
+
+                    The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and
+                     resume_path=None
+
+                - `resume_path`: str (default=None)
+
+                    Explicit checkpoint path (.pth file) to use to resume training.
+
                 - `max_epochs` : int
 
                     Number of epochs to run training.
@@ -859,6 +849,7 @@ class Trainer:
 
         self.train_loader = train_loader or self.train_loader
         self.valid_loader = valid_loader or self.valid_loader
+        self._set_dataset_params()
 
         self.training_params = TrainingParams()
         self.training_params.override(**training_params)
@@ -1120,10 +1111,6 @@ class Trainer:
             self.phase_callback_handler(Phase.POST_TRAINING, context)
 
             if not self.ddp_silent_mode:
-                if self.model_checkpoints_location != 'local':
-                    logger.info('[CLEANUP] - Saving Checkpoint files')
-                    self.sg_logger.upload()
-
                 self.sg_logger.close()
 
     def _reset_best_metric(self):
@@ -1367,10 +1354,7 @@ class Trainer:
             ckpt_local_path = get_ckpt_local_path(source_ckpt_folder_name=self.source_ckpt_folder_name,
                                                   experiment_name=self.experiment_name,
                                                   ckpt_name=self.ckpt_name,
-                                                  model_checkpoints_location=self.model_checkpoints_location,
-                                                  external_checkpoint_path=self.external_checkpoint_path,
-                                                  overwrite_local_checkpoint=self.overwrite_local_checkpoint,
-                                                  load_weights_only=self.load_weights_only)
+                                                  external_checkpoint_path=self.external_checkpoint_path)
 
             # LOAD CHECKPOINT TO MODEL
             self.checkpoint = load_checkpoint_to_model(ckpt_local_path=ckpt_local_path,
@@ -1390,7 +1374,7 @@ class Trainer:
         self.best_metric = self.checkpoint['acc'] if 'acc' in self.checkpoint.keys() else -1
         self.start_epoch = self.checkpoint['epoch'] if 'epoch' in self.checkpoint.keys() else 0
 
-    def _prep_for_test(self, test_loader: torch.utils.data.DataLoader = None, loss=None, post_prediction_callback=None,
+    def _prep_for_test(self, test_loader: torch.utils.data.DataLoader = None, loss=None,
                        test_metrics_list=None,
                        loss_logging_items_names=None, test_phase_callbacks=None):
         """Run commands that are common to all models"""
@@ -1400,7 +1384,6 @@ class Trainer:
         # IF SPECIFIED IN THE FUNCTION CALL - OVERRIDE THE self ARGUMENTS
         self.test_loader = test_loader or self.test_loader
         self.criterion = loss or self.criterion
-        self.post_prediction_callback = post_prediction_callback or self.post_prediction_callback
         self.loss_logging_items_names = loss_logging_items_names or self.loss_logging_items_names
         self.phase_callbacks = test_phase_callbacks or self.phase_callbacks
 
@@ -1445,7 +1428,7 @@ class Trainer:
 
         # OVERRIDE SOME PARAMETERS TO MAKE SURE THEY MATCH THE TRAINING PARAMETERS
         general_sg_logger_params = {'experiment_name': self.experiment_name,
-                                    'storage_location': self.model_checkpoints_location,
+                                    'storage_location': 'local',
                                     'resumed': self.load_checkpoint,
                                     'training_params': self.training_params,
                                     'checkpoints_dir_path': self.checkpoints_dir_path}
Discard
@@ -10,8 +10,7 @@ except (ModuleNotFoundError, ImportError, NameError):
     from torch.hub import _download_url_to_file as download_url_to_file
 
 
-def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, model_checkpoints_location: str, external_checkpoint_path: str,
-                        overwrite_local_checkpoint: bool, load_weights_only: bool):
+def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, external_checkpoint_path: str):
     """
     Gets the local path to the checkpoint file, which will be:
         - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
@@ -24,30 +23,11 @@ def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt
     @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.
     @param experiment_name: experiment name attr in trainer
     @param ckpt_name: checkpoint filename
-    @param model_checkpoints_location: S3, local ot URL
     @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)
-    @param overwrite_local_checkpoint: whether to overwrite the checkpoint file with the same name when downloading from S3.
-    @param load_weights_only: whether to load the network's state dict only.
     @return:
     """
     source_ckpt_folder_name = source_ckpt_folder_name or experiment_name
-    if model_checkpoints_location == 'local':
-        ckpt_local_path = external_checkpoint_path or pkg_resources.resource_filename('checkpoints', source_ckpt_folder_name + os.path.sep + ckpt_name)
-
-    # COPY THE DATA FROM 'S3'/'URL' INTO A LOCAL DIRECTORY
-    elif model_checkpoints_location.startswith('s3') or model_checkpoints_location == 'url':
-        # COPY REMOTE DATA TO A LOCAL DIRECTORY AND GET THAT DIRECTORYs NAME
-        ckpt_local_path = copy_ckpt_to_local_folder(local_ckpt_destination_dir=experiment_name,
-                                                    ckpt_filename=ckpt_name,
-                                                    remote_ckpt_source_dir=source_ckpt_folder_name,
-                                                    path_src=model_checkpoints_location,
-                                                    overwrite_local_ckpt=overwrite_local_checkpoint,
-                                                    load_weights_only=load_weights_only)
-
-    else:
-        # ERROR IN USER CODE FLOW - THIS WILL EVENTUALLY RAISE AN EXCEPTION
-        raise NotImplementedError(
-            'model_checkpoints_data_source: ' + str(model_checkpoints_location) + 'not supported')
+    ckpt_local_path = external_checkpoint_path or pkg_resources.resource_filename('checkpoints', source_ckpt_folder_name + os.path.sep + ckpt_name)
 
     return ckpt_local_path
 
Discard
@@ -19,7 +19,6 @@ class ModelWeightAveraging:
                  source_ckpt_folder_name=None, metric_to_watch='acc',
                  metric_idx=1, load_checkpoint=False,
                  number_of_models_to_average=10,
-                 model_checkpoints_location='local'
                  ):
         """
         Init the ModelWeightAveraging
@@ -45,7 +44,6 @@ class ModelWeightAveraging:
                                                                   source_ckpt_folder_name=source_ckpt_folder_name,
                                                                   ckpt_filename="averaging_snapshots.pkl",
                                                                   load_weights_only=False,
-                                                                  model_checkpoints_location=model_checkpoints_location,
                                                                   overwrite_local_ckpt=True)
 
         else:
Discard
@@ -16,7 +16,7 @@ from super_gradients.training.dataloaders.dataloaders import (
 class TestCifarTrainer(unittest.TestCase):
     def test_train_cifar10_dataloader(self):
         super_gradients.init_trainer()
-        trainer = Trainer("test", model_checkpoints_location="local")
+        trainer = Trainer("test")
         cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
         model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
         trainer.train(
@@ -35,7 +35,7 @@ class TestCifarTrainer(unittest.TestCase):
 
     def test_train_cifar100_dataloader(self):
         super_gradients.init_trainer()
-        trainer = Trainer("test", model_checkpoints_location="local")
+        trainer = Trainer("test")
         cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
         model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
         trainer.train(
Discard
@@ -38,7 +38,7 @@ class TestTrainer(unittest.TestCase):
 
     @staticmethod
     def get_classification_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local')
+        trainer = Trainer(name)
         model = models.get("resnet18", num_classes=5)
         return trainer, model
 
Discard
@@ -69,7 +69,7 @@ class ConversionCallbackTest(unittest.TestCase):
                 "phase_callbacks": phase_callbacks,
             }
 
-            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local",
+            trainer = Trainer(f"{architecture}_example",
                               ckpt_root_dir=checkpoint_dir)
             model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
             try:
@@ -102,7 +102,7 @@ class ConversionCallbackTest(unittest.TestCase):
 
         for architecture in SEMANTIC_SEGMENTATION:
             model_meta_data = generate_model_metadata(architecture=architecture, task=Task.SEMANTIC_SEGMENTATION)
-            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local",
+            trainer = Trainer(f"{architecture}_example",
                               ckpt_root_dir=checkpoint_dir)
             model = models.get(model_name=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
 
Discard
@@ -10,7 +10,7 @@ from deci_lab_client.models import Metric, QuantizationLevel, ModelMetadata, Opt
 
 class DeciLabUploadTest(unittest.TestCase):
     def setUp(self) -> None:
-        self.trainer = Trainer("deci_lab_export_test_model", model_checkpoints_location='local')
+        self.trainer = Trainer("deci_lab_export_test_model")
 
     def test_train_with_deci_lab_integration(self):
         model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
Discard
@@ -23,7 +23,7 @@ class CallWrapper:
 class EMAIntegrationTest(unittest.TestCase):
 
     def _init_model(self) -> None:
-        self.trainer = Trainer("resnet18_cifar_ema_test", model_checkpoints_location='local',
+        self.trainer = Trainer("resnet18_cifar_ema_test",
                                device='cpu', multi_gpu=MultiGPUMode.OFF)
         self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
 
Discard
@@ -30,7 +30,7 @@ class LRTest(unittest.TestCase):
 
     @staticmethod
     def get_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local')
+        trainer = Trainer(name)
         model = models.get("resnet18_cifar", num_classes=5)
         return trainer, model
 
Discard
@@ -235,7 +235,7 @@ class PretrainedModelsTest(unittest.TestCase):
         }
 
     def test_pretrained_resnet50_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet50', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                            **self.imagenet_pretrained_ckpt_params)
@@ -244,7 +244,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
 
     def test_transfer_learning_resnet50_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet50_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -253,7 +253,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_resnet34_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet34', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet34',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -263,7 +263,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
 
     def test_transfer_learning_resnet34_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet34_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet34_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -272,7 +272,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_resnet18_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet18', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet18',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -282,7 +282,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
 
     def test_transfer_learning_resnet18_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet18_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet18_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -291,7 +291,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_regnetY800_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY800', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -301,7 +301,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
 
     def test_transfer_learning_regnetY800_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY800_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -310,7 +310,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_regnetY600_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY600', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY600',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -320,7 +320,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
 
     def test_transfer_learning_regnetY600_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY600_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY600_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -329,7 +329,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_regnetY400_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY400', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY400',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -339,7 +339,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
 
     def test_transfer_learning_regnetY400_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY400_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY400_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -348,7 +348,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_regnetY200_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY200', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY200',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -358,7 +358,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
 
     def test_transfer_learning_regnetY200_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY200_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY200_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -367,7 +367,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_repvgg_a0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_repvgg_a0', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
@@ -377,7 +377,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
 
     def test_transfer_learning_repvgg_a0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_repvgg_a0_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -386,7 +386,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_regseg48_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_regseg48', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_regseg48',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -396,7 +396,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
 
     def test_transfer_learning_regseg48_cityscapes(self):
-        trainer = Trainer('regseg48_cityscapes_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('regseg48_cityscapes_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -405,7 +405,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       training_params=self.regseg_transfer_segmentation_train_params)
 
     def test_pretrained_ddrnet23_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_ddrnet23', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -415,7 +415,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
 
     def test_pretrained_ddrnet23_slim_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -425,7 +425,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
 
     def test_transfer_learning_ddrnet23_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_ddrnet23_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -434,7 +434,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
 
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -443,7 +443,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
 
     def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
-        trainer = Trainer('coco_segmentation_subclass_pretrained_shelfnet34_lw', model_checkpoints_location='local',
+        trainer = Trainer('coco_segmentation_subclass_pretrained_shelfnet34_lw',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("shelfnet34_lw",
                            arch_params=self.coco_segmentation_subclass_pretrained_arch_params["shelfnet34_lw"],
@@ -453,7 +453,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)
 
     def test_pretrained_efficientnet_b0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_efficientnet_b0', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_efficientnet_b0',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
@@ -463,7 +463,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
 
     def test_transfer_learning_efficientnet_b0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_efficientnet_b0_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_efficientnet_b0_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
@@ -473,7 +473,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
-        trainer = Trainer('coco_ssd_lite_mobilenet_v2', model_checkpoints_location='local',
+        trainer = Trainer('coco_ssd_lite_mobilenet_v2',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ssd_lite_mobilenet_v2",
                            arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"],
@@ -485,7 +485,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
         trainer = Trainer('coco_ssd_lite_mobilenet_v2_transfer_learning',
-                          model_checkpoints_location='local', multi_gpu=MultiGPUMode.OFF)
+                          multi_gpu=MultiGPUMode.OFF)
         transfer_arch_params = self.coco_pretrained_arch_params['ssd_lite_mobilenet_v2'].copy()
         transfer_arch_params['num_classes'] = 5
         model = models.get("ssd_lite_mobilenet_v2",
@@ -496,11 +496,11 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_detection_dataset)
 
     def test_pretrained_ssd_mobilenet_v1_coco(self):
-        trainer = Trainer('coco_ssd_mobilenet_v1', model_checkpoints_location='local',
+        trainer = Trainer('coco_ssd_mobilenet_v1',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ssd_mobilenet_v1",
                            arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"],
                            **self.coco_pretrained_ckpt_params)
         ssd_post_prediction_callback = SSDPostPredictCallback()
         res = trainer.test(model=model, test_loader=self.coco_dataset['ssd_mobilenet'],
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback,
@@ -509,7 +508,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)
 
     def test_pretrained_yolox_s_coco(self):
-        trainer = Trainer('yolox_s', model_checkpoints_location='local',
+        trainer = Trainer('yolox_s',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("yolox_s",
@@ -521,7 +520,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_s"], delta=0.001)
 
     def test_pretrained_yolox_m_coco(self):
-        trainer = Trainer('yolox_m', model_checkpoints_location='local',
+        trainer = Trainer('yolox_m',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_m",
                            **self.coco_pretrained_ckpt_params)
@@ -532,7 +531,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_m"], delta=0.001)
 
     def test_pretrained_yolox_l_coco(self):
-        trainer = Trainer('yolox_l', model_checkpoints_location='local',
+        trainer = Trainer('yolox_l',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_l",
                            **self.coco_pretrained_ckpt_params)
@@ -543,7 +542,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_l"], delta=0.001)
 
     def test_pretrained_yolox_n_coco(self):
-        trainer = Trainer('yolox_n', model_checkpoints_location='local',
+        trainer = Trainer('yolox_n',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("yolox_n",
@@ -555,7 +554,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_n"], delta=0.001)
 
     def test_pretrained_yolox_t_coco(self):
-        trainer = Trainer('yolox_t', model_checkpoints_location='local',
+        trainer = Trainer('yolox_t',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_t",
                            **self.coco_pretrained_ckpt_params)
@@ -567,7 +566,6 @@ class PretrainedModelsTest(unittest.TestCase):
 
     def test_transfer_learning_yolox_n_coco(self):
         trainer = Trainer('test_transfer_learning_yolox_n_coco',
-                          model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_n", **self.coco_pretrained_ckpt_params, num_classes=5)
         trainer.train(model=model, training_params=self.transfer_detection_train_params_yolox,
@@ -576,7 +575,6 @@ class PretrainedModelsTest(unittest.TestCase):
 
     def test_transfer_learning_mobilenet_v3_large_imagenet(self):
         trainer = Trainer('imagenet_pretrained_mobilenet_v3_large_transfer_learning',
-                          model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -586,7 +585,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_mobilenet_v3_large_imagenet(self):
-        trainer = Trainer('imagenet_mobilenet_v3_large', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_mobilenet_v3_large',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -597,7 +596,6 @@ class PretrainedModelsTest(unittest.TestCase):
 
     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
         trainer = Trainer('imagenet_pretrained_mobilenet_v3_small_transfer_learning',
-                          model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -607,7 +606,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_mobilenet_v3_small_imagenet(self):
-        trainer = Trainer('imagenet_mobilenet_v3_small', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_mobilenet_v3_small',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -618,7 +617,6 @@ class PretrainedModelsTest(unittest.TestCase):
 
     def test_transfer_learning_mobilenet_v2_imagenet(self):
         trainer = Trainer('imagenet_pretrained_mobilenet_v2_transfer_learning',
-                          model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -628,7 +627,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_mobilenet_v2_imagenet(self):
-        trainer = Trainer('imagenet_mobilenet_v2', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_mobilenet_v2',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -638,7 +637,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
 
     def test_pretrained_stdc1_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc1_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg50',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -648,7 +647,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
 
     def test_transfer_learning_stdc1_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc1_seg50_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg50_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
@@ -657,7 +656,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
 
     def test_pretrained_stdc1_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc1_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg75',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -667,7 +666,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
 
     def test_transfer_learning_stdc1_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc1_seg75_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg75_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
@@ -676,7 +675,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
 
     def test_pretrained_stdc2_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc2_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg50',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -686,7 +685,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
 
     def test_transfer_learning_stdc2_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc2_seg50_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg50_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
@@ -695,7 +694,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
 
     def test_pretrained_stdc2_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc2_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg75',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -705,7 +704,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
 
     def test_transfer_learning_stdc2_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc2_seg75_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg75_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
@@ -715,7 +714,6 @@ class PretrainedModelsTest(unittest.TestCase):
 
     def test_transfer_learning_vit_base_imagenet21k(self):
         trainer = Trainer('imagenet21k_pretrained_vit_base',
-                          model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
@@ -726,7 +725,6 @@ class PretrainedModelsTest(unittest.TestCase):
 
     def test_transfer_learning_vit_large_imagenet21k(self):
         trainer = Trainer('imagenet21k_pretrained_vit_large',
-                          model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
@@ -736,7 +735,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_vit_base_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_vit_base', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_vit_base',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                            **self.imagenet_pretrained_ckpt_params)
@@ -746,7 +745,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
 
     def test_pretrained_vit_large_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_vit_large', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_vit_large',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                            **self.imagenet_pretrained_ckpt_params)
@@ -756,7 +755,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)
 
     def test_pretrained_beit_base_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_beit_base', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_beit_base',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                            **self.imagenet_pretrained_ckpt_params)
@@ -767,7 +766,6 @@ class PretrainedModelsTest(unittest.TestCase):
 
     def test_transfer_learning_beit_base_imagenet(self):
         trainer = Trainer('test_transfer_learning_beit_base_imagenet',
-                          model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
 
         model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
@@ -777,7 +776,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
 
     def test_pretrained_pplite_t_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_pplite_t_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_pplite_t_seg50',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("pp_lite_t_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -789,7 +788,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg50"], delta=0.001)
 
     def test_pretrained_pplite_t_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_pplite_t_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_pplite_t_seg75',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("pp_lite_t_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -801,7 +800,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg75"], delta=0.001)
 
     def test_pretrained_pplite_b_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_pplite_b_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_pplite_b_seg50',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("pp_lite_b_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                            **self.cityscapes_pretrained_ckpt_params)
@@ -813,7 +812,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg50"], delta=0.001)
 
     def test_pretrained_pplite_b_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_pplite_b_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_pplite_b_seg75',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("pp_lite_b_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                            **self.cityscapes_pretrained_ckpt_params)
Discard
@@ -10,7 +10,6 @@ from super_gradients.training.utils.quantization_utils import PostQATConversionC
 class QATIntegrationTest(unittest.TestCase):
     def _get_trainer(self, experiment_name):
         trainer = Trainer(experiment_name,
-                          model_checkpoints_location='local',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet18", pretrained_weights="imagenet")
         return trainer, model
Discard
@@ -18,9 +18,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
         """
         # Create dataset
 
-        trainer = Trainer('dataset_statistics_visual_test',
-                          model_checkpoints_location='local',
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer('dataset_statistics_visual_test')
 
         model = models.get("yolox_s")
 
Discard
@@ -12,10 +12,9 @@ class TestDetectionUtils(unittest.TestCase):
     def test_visualization(self):
 
         # Create Yolo model
-        trainer = Trainer('visualization_test',
-                          model_checkpoints_location='local',
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer('visualization_test')
         model = models.get("yolox_n", pretrained_weights="coco")
+        post_prediction_callback = YoloPostPredictionCallback()
 
         # Simulate one iteration of validation subset
         valid_loader = coco2017_val()
@@ -23,7 +22,7 @@ class TestDetectionUtils(unittest.TestCase):
         imgs = core_utils.tensor_container_to_device(imgs, trainer.device)
         targets = core_utils.tensor_container_to_device(targets, trainer.device)
         output = model(imgs)
-        output = trainer.post_prediction_callback(output)
+        output = post_prediction_callback(output)
         # Visualize the batch
         DetectionVisualization.visualize_batch(imgs, output, targets, batch_i,
                                                COCO_DETECTION_CLASSES_LIST, trainer.checkpoints_dir_path)
Discard
@@ -58,7 +58,7 @@ class EarlyStopTest(unittest.TestCase):
         Test for mode=min metric, test that training stops after no improvement in metric value for amount of `patience`
         epochs.
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", patience=3, verbose=True)
         phase_callbacks = [early_stop_loss]
@@ -80,7 +80,7 @@ class EarlyStopTest(unittest.TestCase):
         Test for mode=max metric, test that training stops after no improvement in metric value for amount of `patience`
         epochs.
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=3,
                                    verbose=True)
         phase_callbacks = [early_stop_acc]
@@ -101,7 +101,7 @@ class EarlyStopTest(unittest.TestCase):
         """
         Test for mode=min metric, test that training stops after metric value reaches the `threshold` value.
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", threshold=0.1, verbose=True)
         phase_callbacks = [early_stop_loss]
@@ -121,7 +121,7 @@ class EarlyStopTest(unittest.TestCase):
         """
         Test for mode=max metric, test that training stops after metric value reaches the `threshold` value.
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", threshold=0.94,
                                    verbose=True)
@@ -144,7 +144,7 @@ class EarlyStopTest(unittest.TestCase):
         Test that training stops when monitor value is not a finite number. Test case of NaN and Inf values.
         """
         # test Nan value
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", check_finite=True,
                                     verbose=True)
@@ -162,7 +162,7 @@ class EarlyStopTest(unittest.TestCase):
         self.assertEqual(excepted_end_epoch, fake_loss.count // 2)
 
         # test Inf value
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", patience=3, verbose=True)
         phase_callbacks = [early_stop_loss]
@@ -183,7 +183,7 @@ class EarlyStopTest(unittest.TestCase):
         Test for `min_delta` argument, metric value is considered an improvement only if
         current_value - min_delta > best_value
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=2,
                                    min_delta=0.1, verbose=True)
Discard
@@ -11,7 +11,7 @@ from super_gradients.training.metrics import Accuracy, Top5
 class FactoriesTest(unittest.TestCase):
 
     def test_training_with_factories(self):
-        trainer = Trainer("test_train_with_factories", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_factories")
         net = models.get("resnet18", num_classes=5)
         train_params = {"max_epochs": 2,
                         "lr_updates": [1],
Discard
@@ -6,7 +6,6 @@ from super_gradients import Trainer
 import torch
 from torch.utils.data import TensorDataset, DataLoader
 from super_gradients.training.metrics import Accuracy
-from super_gradients.training.exceptions.sg_trainer_exceptions import IllegalDataloaderInitialization
 
 
 class InitializeWithDataloadersTest(unittest.TestCase):
@@ -26,27 +25,8 @@ class InitializeWithDataloadersTest(unittest.TestCase):
         label = torch.randint(0, len(self.testcase_classes), size=(test_size,))
         self.testcase_testloader = DataLoader(TensorDataset(inp, label))
 
-    def test_initialization_rules(self):
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader, valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader, classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
-                valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
-                valid_loader=self.testcase_validloader, test_loader=self.testcase_testloader,
-                classes=self.testcase_classes)
-
     def test_train_with_dataloaders(self):
-        trainer = Trainer(experiment_name="test_name", model_checkpoints_location="local")
+        trainer = Trainer(experiment_name="test_name")
         model = models.get("resnet18", num_classes=5)
         trainer.train(model=model,
                       training_params={"max_epochs": 2,
Discard
@@ -29,7 +29,7 @@ class LoadCheckpointWithEmaTest(unittest.TestCase):
     def test_ema_ckpt_reload(self):
         # Define Model
         net = LeNet()
-        trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("ema_ckpt_test")
         trainer.train(model=net, training_params=self.train_params,
                       train_loader=classification_test_dataloader(),
                       valid_loader=classification_test_dataloader())
@@ -38,7 +38,7 @@ class LoadCheckpointWithEmaTest(unittest.TestCase):
 
         # TRAIN FOR 1 MORE EPOCH AND COMPARE THE NET AT THE BEGINNING OF EPOCH 3 AND THE END OF EPOCH NUMBER 2
         net = LeNet()
-        trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("ema_ckpt_test")
 
         net_collector = PreTrainingEMANetCollector()
         self.train_params["resume"] = True
Discard
@@ -10,7 +10,7 @@ class LRCooldownTest(unittest.TestCase):
     def test_lr_cooldown_with_lr_scheduling(self):
         # Define Model
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
Discard
@@ -38,7 +38,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_lr_warmup(self):
         # Define Model
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -60,7 +60,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_lr_warmup_with_lr_scheduling(self):
         # Define model
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -85,7 +85,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_warmup_initial_lr(self):
         # Define model
         net = LeNet()
-        trainer = Trainer("test_warmup_initial_lr", model_checkpoints_location='local')
+        trainer = Trainer("test_warmup_initial_lr")
 
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -107,7 +107,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_custom_lr_warmup(self):
         # Define model
         net = LeNet()
-        trainer = Trainer("custom_lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("custom_lr_warmup_test")
 
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
Discard
@@ -12,7 +12,7 @@ from torchmetrics import MetricCollection
 
 class PhaseContextTest(unittest.TestCase):
     def context_information_in_train_test(self):
-        trainer = Trainer("context_information_in_train_test", model_checkpoints_location='local')
+        trainer = Trainer("context_information_in_train_test")
 
         net = ResNet18(num_classes=5, arch_params={})
 
Discard
@@ -31,7 +31,7 @@ class ContextMethodsCheckerCallback(PhaseCallback):
 class ContextMethodsTest(unittest.TestCase):
     def test_access_to_methods_by_phase(self):
         net = LeNet()
-        trainer = Trainer("test_access_to_methods_by_phase", model_checkpoints_location='local')
+        trainer = Trainer("test_access_to_methods_by_phase")
 
         phase_callbacks = []
         for phase in Phase:
Discard
@@ -14,21 +14,21 @@ class PretrainedModelsUnitTest(unittest.TestCase):
         self.imagenet_pretrained_models = ["resnet50", "repvgg_a0", "regnetY800"]
 
     def test_pretrained_resnet50_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet50_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50_unit_test',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet50", pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
                      metrics_progress_verbose=True)
 
     def test_pretrained_regnetY800_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY800_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800_unit_test',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY800", pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
                      metrics_progress_verbose=True)
 
     def test_pretrained_repvgg_a0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_repvgg_a0_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0_unit_test',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
Discard
@@ -18,7 +18,7 @@ class SaveCkptListUnitTest(unittest.TestCase):
                         "greater_metric_to_watch_is_better": True}
 
         # Define Model
-        trainer = Trainer("save_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("save_ckpt_test")
 
         # Build Model
         model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
Discard
@@ -53,7 +53,7 @@ class StrictLoadEnumTest(unittest.TestCase):
         torch.save(cls.change_state_dict_keys(cls.original_torch_model.state_dict()), cls.checkpoint_diff_keys_path)
 
         # Save the model's state_dict checkpoint in Trainer format
-        cls.trainer = Trainer("load_checkpoint_test", model_checkpoints_location='local')  # Saves in /checkpoints
+        cls.trainer = Trainer("load_checkpoint_test")  # Saves in /checkpoints
         cls.trainer.set_net(cls.original_torch_model)
         # FIXME: after uniting init and build_model we should remove this
         cls.trainer.sg_logger = BaseSGLogger('project_name', 'load_checkpoint_test', 'local', resumed=False,
Discard
@@ -6,9 +6,9 @@ from super_gradients.training.dataloaders.dataloaders import classification_test
     detection_test_dataloader, segmentation_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training import MultiGPUMode, models
-from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
+from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 
 
 class TestWithoutTrainTest(unittest.TestCase):
@@ -26,22 +26,21 @@ class TestWithoutTrainTest(unittest.TestCase):
 
     @staticmethod
     def get_classification_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local')
+        trainer = Trainer(name)
         model = models.get("resnet18", num_classes=5)
         return trainer, model
 
     @staticmethod
     def get_detection_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local',
-                          multi_gpu=MultiGPUMode.OFF,
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer(name,
+                          multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_s", num_classes=5)
         return trainer, model
 
     @staticmethod
     def get_segmentation_trainer(name=''):
         shelfnet_lw_arch_params = {"num_classes": 5, "load_checkpoint": False}
-        trainer = Trainer(name, model_checkpoints_location='local', multi_gpu=False)
+        trainer = Trainer(name, multi_gpu=MultiGPUMode.OFF)
         model = models.get('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
         return trainer, model
 
@@ -52,7 +51,7 @@ class TestWithoutTrainTest(unittest.TestCase):
 
         trainer, model = self.get_detection_trainer(self.folder_names[1])
 
-        test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
+        test_metrics = [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=5)]
 
         assert isinstance(trainer.test(model=model, silent_mode=True,
                                        test_metrics_list=test_metrics, test_loader=detection_test_dataloader(image_size=320)), tuple)
Discard
@@ -11,7 +11,7 @@ import shutil
 
 class SgTrainerLoggingTest(unittest.TestCase):
     def test_train_logging(self):
-        trainer = Trainer("test_train_with_full_log", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_full_log")
 
         net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
Discard
@@ -19,7 +19,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
     """
 
     def test_train_with_external_criterion(self):
-        trainer = Trainer("external_criterion_test", model_checkpoints_location='local')
+        trainer = Trainer("external_criterion_test")
         dataloader = classification_test_dataloader(batch_size=10)
 
         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -33,7 +33,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
 
     def test_train_with_external_optimizer(self):
-        trainer = Trainer("external_optimizer_test", model_checkpoints_location='local')
+        trainer = Trainer("external_optimizer_test")
         dataloader = classification_test_dataloader(batch_size=10)
 
         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -47,7 +47,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
 
     def test_train_with_external_scheduler(self):
-        trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)
 
         lr = 0.3
@@ -66,7 +66,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         self.assertTrue(lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1)
 
     def test_train_with_external_scheduler_class(self):
-        trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)
 
         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -81,7 +81,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
 
     def test_train_with_reduce_on_plateau(self):
-        trainer = Trainer("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_reduce_on_plateau_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)
 
         lr = 0.3
@@ -101,7 +101,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         self.assertTrue(lr_scheduler._last_lr[0] == lr * 0.1)
 
     def test_train_with_external_metric(self):
-        trainer = Trainer("external_metric_test", model_checkpoints_location='local')
+        trainer = Trainer("external_metric_test")
         dataloader = classification_test_dataloader(batch_size=10)
 
         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -114,7 +114,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
 
     def test_train_with_external_dataloaders(self):
-        trainer = Trainer("external_data_loader_test", model_checkpoints_location='local')
+        trainer = Trainer("external_data_loader_test")
 
         batch_size = 5
         trainset = torch.utils.data.TensorDataset(torch.Tensor(np.random.random((10, 3, 32, 32))),
Discard
@@ -12,7 +12,7 @@ class TrainWithPreciseBNTest(unittest.TestCase):
     """
 
     def test_train_with_precise_bn_explicit_size(self):
-        trainer = Trainer("test_train_with_precise_bn_explicit_size", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_precise_bn_explicit_size")
         net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
@@ -26,7 +26,7 @@ class TrainWithPreciseBNTest(unittest.TestCase):
                       valid_loader=classification_test_dataloader(batch_size=10))
 
     def test_train_with_precise_bn_implicit_size(self):
-        trainer = Trainer("test_train_with_precise_bn_implicit_size", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_precise_bn_implicit_size")
 
         net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
Discard
@@ -29,7 +29,7 @@ class UpdateParamGroupsTest(unittest.TestCase):
     def test_lr_scheduling_with_update_param_groups(self):
         # Define Model
         net = TestNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
Discard
Discard