Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#396 Trainer constructor cleanup

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
49 changed files with 176 additions and 246 deletions
  1. 0
    1
      src/super_gradients/examples/resnet_qat/resnet_qat_example.py
  2. 1
    1
      src/super_gradients/recipes/cifar10_resnet.yaml
  3. 1
    1
      src/super_gradients/recipes/cityscapes_ddrnet.yaml
  4. 1
    1
      src/super_gradients/recipes/cityscapes_regseg48.yaml
  5. 1
    1
      src/super_gradients/recipes/cityscapes_stdc_base.yaml
  6. 1
    1
      src/super_gradients/recipes/coco2017_ssd_lite_mobilenet_v2.yaml
  7. 1
    1
      src/super_gradients/recipes/coco2017_yolox.yaml
  8. 1
    1
      src/super_gradients/recipes/coco_segmentation_shelfnet_lw.yaml
  9. 1
    1
      src/super_gradients/recipes/imagenet_efficientnet.yaml
  10. 1
    1
      src/super_gradients/recipes/imagenet_mobilenetv2.yaml
  11. 1
    1
      src/super_gradients/recipes/imagenet_mobilenetv3_base.yaml
  12. 1
    1
      src/super_gradients/recipes/imagenet_regnetY.yaml
  13. 1
    1
      src/super_gradients/recipes/imagenet_repvgg.yaml
  14. 1
    1
      src/super_gradients/recipes/imagenet_resnet50.yaml
  15. 1
    1
      src/super_gradients/recipes/imagenet_resnet50_kd.yaml
  16. 1
    1
      src/super_gradients/recipes/imagenet_vit_base.yaml
  17. 2
    0
      src/super_gradients/recipes/training_hyperparams/default_train_params.yaml
  18. 3
    11
      src/super_gradients/training/kd_trainer/kd_trainer.py
  19. 1
    0
      src/super_gradients/training/params.py
  20. 40
    57
      src/super_gradients/training/sg_trainer/sg_trainer.py
  21. 2
    22
      src/super_gradients/training/utils/checkpoint_utils.py
  22. 0
    2
      src/super_gradients/training/utils/weight_averaging_utils.py
  23. 2
    2
      tests/end_to_end_tests/cifar_trainer_test.py
  24. 1
    1
      tests/end_to_end_tests/trainer_test.py
  25. 2
    2
      tests/integration_tests/conversion_callback_test.py
  26. 1
    1
      tests/integration_tests/deci_lab_export_test.py
  27. 1
    1
      tests/integration_tests/ema_train_integration_test.py
  28. 1
    1
      tests/integration_tests/lr_test.py
  29. 59
    60
      tests/integration_tests/pretrained_models_test.py
  30. 1
    1
      tests/integration_tests/qat_integration_test.py
  31. 1
    3
      tests/unit_tests/dataset_statistics_test.py
  32. 3
    4
      tests/unit_tests/detection_utils_test.py
  33. 7
    7
      tests/unit_tests/early_stop_test.py
  34. 1
    1
      tests/unit_tests/factories_test.py
  35. 1
    21
      tests/unit_tests/initialize_with_dataloaders_test.py
  36. 2
    2
      tests/unit_tests/load_ema_ckpt_test.py
  37. 1
    1
      tests/unit_tests/lr_cooldown_test.py
  38. 4
    4
      tests/unit_tests/lr_warmup_test.py
  39. 1
    1
      tests/unit_tests/phase_context_test.py
  40. 1
    1
      tests/unit_tests/phase_delegates_test.py
  41. 3
    3
      tests/unit_tests/pretrained_models_unit_test.py
  42. 1
    1
      tests/unit_tests/save_ckpt_test.py
  43. 1
    1
      tests/unit_tests/strictload_enum_test.py
  44. 6
    7
      tests/unit_tests/test_without_train_test.py
  45. 1
    1
      tests/unit_tests/train_logging_test.py
  46. 7
    7
      tests/unit_tests/train_with_intialized_param_args_test.py
  47. 2
    2
      tests/unit_tests/train_with_precise_bn_test.py
  48. 1
    1
      tests/unit_tests/update_param_groups_unit_test.py
  49. 1
    1
      tutorials/what_are_recipes_and_how_to_use.ipynb
@@ -25,7 +25,6 @@ from super_gradients.training.utils.quantization_utils import PostQATConversionC
 super_gradients.init_trainer()
 super_gradients.init_trainer()
 
 
 trainer = Trainer("resnet18_qat_example",
 trainer = Trainer("resnet18_qat_example",
-                  model_checkpoints_location='local',
                   multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
                   multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
 
 
 train_loader = dataloaders.imagenet_train()
 train_loader = dataloaders.imagenet_train()
Discard
@@ -22,7 +22,7 @@ training_hyperparams:
   resume: ${resume}
   resume: ${resume}
 
 
 
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 ckpt_root_dir:
 
 
 architecture: resnet18_cifar
 architecture: resnet18_cifar
Discard
@@ -78,7 +78,7 @@ checkpoint_params:
 
 
 experiment_name: ${architecture}_cityscapes
 experiment_name: ${architecture}_cityscapes
 
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 ckpt_root_dir:
 
 
 multi_gpu: DDP
 multi_gpu: DDP
Discard
@@ -44,7 +44,7 @@ arch_params:
   strict_load: no_key_matching
   strict_load: no_key_matching
 
 
 load_checkpoint: False
 load_checkpoint: False
-model_checkpoints_location: local
+
 ckpt_root_dir:
 ckpt_root_dir:
 
 
 resume: False
 resume: False
Discard
@@ -24,7 +24,7 @@ checkpoint_params:
 architecture: stdc1_seg
 architecture: stdc1_seg
 experiment_name: ${architecture}_cityscapes
 experiment_name: ${architecture}_cityscapes
 
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 ckpt_root_dir:
 
 
 multi_gpu: DDP
 multi_gpu: DDP
Discard
@@ -34,7 +34,7 @@ val_dataloader: coco2017_val
 architecture: ssd_lite_mobilenet_v2
 architecture: ssd_lite_mobilenet_v2
 
 
 data_loader_num_workers: 8
 data_loader_num_workers: 8
-model_checkpoints_location: local
+
 experiment_suffix: res${dataset_params.train_image_size}
 experiment_suffix: res${dataset_params.train_image_size}
 experiment_name: ${architecture}_coco_${experiment_suffix}
 experiment_name: ${architecture}_coco_${experiment_suffix}
 
 
Discard
@@ -40,7 +40,7 @@ defaults:
 train_dataloader: coco2017_train
 train_dataloader: coco2017_train
 val_dataloader: coco2017_val
 val_dataloader: coco2017_val
 
 
-model_checkpoints_location: local
+
 
 
 load_checkpoint: False
 load_checkpoint: False
 resume: False
 resume: False
Discard
@@ -9,7 +9,7 @@
 #   0. Make sure that the data is stored in dataset_params.dataset_dir or add "dataset_params.data_dir=<PATH-TO-DATASET>" at the end of the command below (feel free to check ReadMe)
 #   0. Make sure that the data is stored in dataset_params.dataset_dir or add "dataset_params.data_dir=<PATH-TO-DATASET>" at the end of the command below (feel free to check ReadMe)
 #   1. Move to the project root (where you will find the ReadMe and src folder)
 #   1. Move to the project root (where you will find the ReadMe and src folder)
 #   2. Run the command:
 #   2. Run the command:
-#       python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco_segmentation_shelfnet_lw --model_checkpoints_location=<checkpoint-location>
+#       python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco_segmentation_shelfnet_lw
 
 
 
 
 # /!\ THIS RECIPE IS NOT SUPPORTED AT THE MOMENT /!\
 # /!\ THIS RECIPE IS NOT SUPPORTED AT THE MOMENT /!\
Discard
@@ -27,7 +27,7 @@ training_hyperparams:
 
 
 experiment_name: efficientnet_b0_imagenet
 experiment_name: efficientnet_b0_imagenet
 
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 ckpt_root_dir:
 
 
 multi_gpu: DDP
 multi_gpu: DDP
Discard
@@ -25,7 +25,7 @@ arch_params:
 
 
 data_loader_num_workers: 8
 data_loader_num_workers: 8
 
 
-model_checkpoints_location: local
+
 resume: False
 resume: False
 training_hyperparams:
 training_hyperparams:
   resume: ${resume}
   resume: ${resume}
Discard
@@ -8,7 +8,7 @@ defaults:
 train_dataloader: imagenet_train
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 val_dataloader: imagenet_val
 
 
-model_checkpoints_location: local
+
 resume: False
 resume: False
 training_hyperparams:
 training_hyperparams:
   resume: ${resume}
   resume: ${resume}
Discard
@@ -36,7 +36,7 @@ arch_params:
 train_dataloader: imagenet_train
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 val_dataloader: imagenet_val
 
 
-model_checkpoints_location: local
+
 load_checkpoint: False
 load_checkpoint: False
 resume: False
 resume: False
 training_hyperparams:
 training_hyperparams:
Discard
@@ -25,7 +25,7 @@ train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 val_dataloader: imagenet_val
 
 
 
 
-model_checkpoints_location: local
+
 resume: False
 resume: False
 training_hyperparams:
 training_hyperparams:
   resume: ${resume}
   resume: ${resume}
Discard
@@ -24,7 +24,7 @@ arch_params:
 train_dataloader: imagenet_train
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 val_dataloader: imagenet_val
 
 
-model_checkpoints_location: local
+
 resume: False
 resume: False
 training_hyperparams:
 training_hyperparams:
   resume: ${resume}
   resume: ${resume}
Discard
@@ -66,7 +66,7 @@ student_checkpoint_params:
   pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenent").
   pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenent").
 
 
 
 
-model_checkpoints_location: local
+
 
 
 
 
 run_teacher_on_eval: True
 run_teacher_on_eval: True
Discard
@@ -21,7 +21,7 @@ defaults:
 train_dataloader: imagenet_train
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 val_dataloader: imagenet_val
 
 
-model_checkpoints_location: local
+
 
 
 resume: False
 resume: False
 training_hyperparams:
 training_hyperparams:
Discard
@@ -1,4 +1,6 @@
 resume: False # whether to continue training from ckpt with the same experiment name.
 resume: False # whether to continue training from ckpt with the same experiment name.
+resume_path: # Explicit checkpoint path (.pth file) to use to resume training.
+ckpt_name: ckpt_latest.pth  # The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and resume_path=None
 lr_mode: # Learning rate scheduling policy, one of ['step','poly','cosine','function']
 lr_mode: # Learning rate scheduling policy, one of ['step','poly','cosine','function']
 lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
 lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
 lr_warmup_epochs: 0 # number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
 lr_warmup_epochs: 0 # number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
Discard
@@ -9,7 +9,7 @@ from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.sg_trainer import Trainer
 from super_gradients.training.sg_trainer import Trainer
-from typing import Union, List, Any
+from typing import Union
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
@@ -20,7 +20,6 @@ from super_gradients.training.exceptions.kd_trainer_exceptions import Architectu
     UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
     UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
     TeacherKnowledgeException, UndefinedNumClassesException
     TeacherKnowledgeException, UndefinedNumClassesException
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
-from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.utils.ema import KDModelEMA
 from super_gradients.training.utils.ema import KDModelEMA
 from super_gradients.training.utils.sg_trainer_utils import parse_args
 from super_gradients.training.utils.sg_trainer_utils import parse_args
 
 
@@ -29,15 +28,8 @@ logger = get_logger(__name__)
 
 
 class KDTrainer(Trainer):
 class KDTrainer(Trainer):
     def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
     def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 model_checkpoints_location: str = 'local', overwrite_local_checkpoint: bool = True,
-                 ckpt_name: str = 'ckpt_latest.pth',
-                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
-                 train_loader: DataLoader = None,
-                 valid_loader: DataLoader = None, test_loader: DataLoader = None, classes: List[Any] = None):
-
-        super().__init__(experiment_name, device, multi_gpu, model_checkpoints_location, overwrite_local_checkpoint,
-                         ckpt_name, post_prediction_callback,
-                         ckpt_root_dir, train_loader, valid_loader, test_loader, classes)
+                 ckpt_root_dir: str = None):
+        super().__init__(experiment_name=experiment_name, device=device, multi_gpu=multi_gpu, ckpt_root_dir=ckpt_root_dir)
         self.student_architecture = None
         self.student_architecture = None
         self.teacher_architecture = None
         self.teacher_architecture = None
         self.student_arch_params = None
         self.student_arch_params = None
Discard
@@ -63,6 +63,7 @@ DEFAULT_TRAINING_PARAMS = {"lr_warmup_epochs": 0,
                            },
                            },
                            "resume": False,
                            "resume": False,
                            "resume_path": None,
                            "resume_path": None,
+                           "ckpt_name": 'ckpt_latest.pth',
                            "resume_strict_load": False
                            "resume_strict_load": False
                            }
                            }
 
 
Discard
@@ -2,14 +2,14 @@ import inspect
 import os
 import os
 import sys
 import sys
 from copy import deepcopy
 from copy import deepcopy
-from typing import Union, Tuple, Mapping, List, Any
+from typing import Union, Tuple, Mapping
 
 
 import hydra
 import hydra
 import numpy as np
 import numpy as np
 import torch
 import torch
 from omegaconf import DictConfig
 from omegaconf import DictConfig
 from torch import nn
 from torch import nn
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
 from torch.cuda.amp import GradScaler, autocast
 from torch.cuda.amp import GradScaler, autocast
 from torchmetrics import MetricCollection
 from torchmetrics import MetricCollection
 from tqdm import tqdm
 from tqdm import tqdm
@@ -33,14 +33,12 @@ from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.utils import sg_trainer_utils
 from super_gradients.training.utils import sg_trainer_utils
 from super_gradients.training.utils.quantization_utils import QATCallback
 from super_gradients.training.utils.quantization_utils import QATCallback
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args
-from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, \
-    IllegalDataloaderInitialization, GPUModeNotSetupError
+from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, GPUModeNotSetupError
 from super_gradients.training.losses import LOSSES
 from super_gradients.training.losses import LOSSES
 from super_gradients.training.metrics.metric_utils import get_metrics_titles, get_metrics_results_tuple, \
 from super_gradients.training.metrics.metric_utils import get_metrics_titles, get_metrics_results_tuple, \
     get_logging_values, \
     get_logging_values, \
     get_metrics_dict, get_train_loop_description_dict
     get_metrics_dict, get_train_loop_description_dict
 from super_gradients.training.params import TrainingParams
 from super_gradients.training.params import TrainingParams
-from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.utils.distributed_training_utils import MultiGPUModeAutocastWrapper, \
 from super_gradients.training.utils.distributed_training_utils import MultiGPUModeAutocastWrapper, \
     reduce_results_tuple_for_ddp, compute_precise_bn_stats, setup_gpu_mode, require_gpu_setup
     reduce_results_tuple_for_ddp, compute_precise_bn_stats, setup_gpu_mode, require_gpu_setup
 from super_gradients.training.utils.ema import ModelEMA
 from super_gradients.training.utils.ema import ModelEMA
@@ -76,30 +74,19 @@ class Trainer:
         returns the test loss, accuracy and runtime
         returns the test loss, accuracy and runtime
     """
     """
 
 
-    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 model_checkpoints_location: str = 'local',
-                 overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth',
-                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
-                 train_loader: DataLoader = None, valid_loader: DataLoader = None, test_loader: DataLoader = None,
-                 classes: List[Any] = None):
+    def __init__(self, experiment_name: str, device: str = None,
+                 multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
+                 ckpt_root_dir: str = None):
         """
         """
 
 
         :param experiment_name:                      Used for logging and loading purposes
         :param experiment_name:                      Used for logging and loading purposes
         :param device:                          If equal to 'cpu' runs on the CPU otherwise on GPU
         :param device:                          If equal to 'cpu' runs on the CPU otherwise on GPU
         :param multi_gpu:                       If True, runs on all available devices
         :param multi_gpu:                       If True, runs on all available devices
-        :param model_checkpoints_location:      If set to 's3' saves the Checkpoints in AWS S3
                                                 otherwise saves the Checkpoints Locally
                                                 otherwise saves the Checkpoints Locally
-        :param overwrite_local_checkpoint:      If set to False keeps the current local checkpoint when importing
                                                 checkpoint from cloud service, otherwise overwrites the local checkpoints file
                                                 checkpoint from cloud service, otherwise overwrites the local checkpoints file
-        :param ckpt_name:                       The Checkpoint to Load
         :param ckpt_root_dir:                   Local root directory path where all experiment logging directories will
         :param ckpt_root_dir:                   Local root directory path where all experiment logging directories will
                                                 reside. When none is give, it is assumed that
                                                 reside. When none is give, it is assumed that
                                                 pkg_resources.resource_filename('checkpoints', "") exists and will be used.
                                                 pkg_resources.resource_filename('checkpoints', "") exists and will be used.
-        :param train_loader:                    Training set Dataloader instead of using DatasetInterface, must pass "valid_loader"
-                                                and "classes" along with it
-        :param valid_loader:                    Validation set Dataloader
-        :param test_loader:                     Test set Dataloader
-        :param classes:                         List of class labels
 
 
         """
         """
         # SET THE EMPTY PROPERTIES
         # SET THE EMPTY PROPERTIES
@@ -109,7 +96,6 @@ class Trainer:
         self.ema_model = None
         self.ema_model = None
         self.sg_logger = None
         self.sg_logger = None
         self.update_param_groups = None
         self.update_param_groups = None
-        self.post_prediction_callback = None
         self.criterion = None
         self.criterion = None
         self.training_params = None
         self.training_params = None
         self.scaler = None
         self.scaler = None
@@ -144,10 +130,7 @@ class Trainer:
 
 
         # SETTING THE PROPERTIES FROM THE CONSTRUCTOR
         # SETTING THE PROPERTIES FROM THE CONSTRUCTOR
         self.experiment_name = experiment_name
         self.experiment_name = experiment_name
-        self.ckpt_name = ckpt_name
-        self.overwrite_local_checkpoint = overwrite_local_checkpoint
-        self.model_checkpoints_location = model_checkpoints_location
-        self._set_dataset_properties(classes, test_loader, train_loader, valid_loader)
+        self.ckpt_name = None
 
 
         # CREATING THE LOGGING DIR BASED ON THE INPUT PARAMS TO PREVENT OVERWRITE OF LOCAL VERSION
         # CREATING THE LOGGING DIR BASED ON THE INPUT PARAMS TO PREVENT OVERWRITE OF LOCAL VERSION
         if ckpt_root_dir:
         if ckpt_root_dir:
@@ -161,7 +144,6 @@ class Trainer:
         # INITIALIZE THE DEVICE FOR THE MODEL
         # INITIALIZE THE DEVICE FOR THE MODEL
         self._initialize_device(requested_device=device, requested_multi_gpu=multi_gpu)
         self._initialize_device(requested_device=device, requested_multi_gpu=multi_gpu)
 
 
-        self.post_prediction_callback = post_prediction_callback
         # SET THE DEFAULTS
         # SET THE DEFAULTS
         # TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK
         # TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK
 
 
@@ -221,25 +203,18 @@ class Trainer:
                       valid_loader=val_dataloader,
                       valid_loader=val_dataloader,
                       training_params=cfg.training_hyperparams)
                       training_params=cfg.training_hyperparams)
 
 
-    def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
-        if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
-            raise IllegalDataloaderInitialization()
-
-        dataset_params = {"batch_size": train_loader.batch_size if train_loader else None,
-                          "val_batch_size": valid_loader.batch_size if valid_loader else None,
-                          "test_batch_size": test_loader.batch_size if test_loader else None,
-                          "dataset_dir": None,
-                          "s3_link": None}
-
-        if train_loader and self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
-            if not all([isinstance(train_loader.sampler, DistributedSampler),
-                        isinstance(valid_loader.sampler, DistributedSampler),
-                        test_loader is None or isinstance(test_loader.sampler, DistributedSampler)]):
-                logger.warning(
-                    "DDP training was selected but the dataloader samplers are not of type DistributedSamplers")
-
-        self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
-            HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
+    def _set_dataset_params(self):
+        self.dataset_params = {
+            "train_dataset_params": self.train_loader.dataset.dataset_params if hasattr(self.train_loader.dataset,
+                                                                                        "dataset_params") else None,
+            "train_dataloader_params": self.train_loader.dataloader_params if hasattr(self.train_loader,
+                                                                                      "dataloader_params") else None,
+            "valid_dataset_params": self.valid_loader.dataset.dataset_params if hasattr(self.valid_loader.dataset,
+                                                                                        "dataset_params") else None,
+            "valid_dataloader_params": self.valid_loader.dataloader_params if hasattr(self.valid_loader,
+                                                                                      "dataloader_params") else None
+        }
+        self.dataset_params = HpmStruct(**self.dataset_params)
 
 
     def _set_ckpt_loading_attributes(self):
     def _set_ckpt_loading_attributes(self):
         """
         """
@@ -408,8 +383,7 @@ class Trainer:
                                                                source_ckpt_folder_name=self.source_ckpt_folder_name,
                                                                source_ckpt_folder_name=self.source_ckpt_folder_name,
                                                                metric_to_watch=self.metric_to_watch,
                                                                metric_to_watch=self.metric_to_watch,
                                                                metric_idx=self.metric_idx_in_results_tuple,
                                                                metric_idx=self.metric_idx_in_results_tuple,
-                                                               load_checkpoint=self.load_checkpoint,
-                                                               model_checkpoints_location=self.model_checkpoints_location)
+                                                               load_checkpoint=self.load_checkpoint)
 
 
     def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context: PhaseContext, *args, **kwargs):
     def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context: PhaseContext, *args, **kwargs):
         """
         """
@@ -520,6 +494,7 @@ class Trainer:
         self.load_ema_as_net = False
         self.load_ema_as_net = False
         self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
         self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
         self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
         self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
+        self.ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", 'ckpt_latest.pth')
         self._load_checkpoint_to_model()
         self._load_checkpoint_to_model()
 
 
     def _init_arch_params(self):
     def _init_arch_params(self):
@@ -546,6 +521,21 @@ class Trainer:
             :param train_loader: Dataloader for train set.
             :param train_loader: Dataloader for train set.
             :param valid_loader: Dataloader for validation.
             :param valid_loader: Dataloader for validation.
             :param training_params:
             :param training_params:
+
+                - `resume` : bool (default=False)
+
+                    Whether to continue training from ckpt with the same experiment name
+                     (i.e resume from CKPT_ROOT_DIR/EXPERIMENT_NAME/CKPT_NAME)
+
+                - `ckpt_name` : str (default=ckpt_latest.pth)
+
+                    The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and
+                     resume_path=None
+
+                - `resume_path`: str (default=None)
+
+                    Explicit checkpoint path (.pth file) to use to resume training.
+
                 - `max_epochs` : int
                 - `max_epochs` : int
 
 
                     Number of epochs to run training.
                     Number of epochs to run training.
@@ -859,6 +849,7 @@ class Trainer:
 
 
         self.train_loader = train_loader or self.train_loader
         self.train_loader = train_loader or self.train_loader
         self.valid_loader = valid_loader or self.valid_loader
         self.valid_loader = valid_loader or self.valid_loader
+        self._set_dataset_params()
 
 
         self.training_params = TrainingParams()
         self.training_params = TrainingParams()
         self.training_params.override(**training_params)
         self.training_params.override(**training_params)
@@ -1120,10 +1111,6 @@ class Trainer:
             self.phase_callback_handler(Phase.POST_TRAINING, context)
             self.phase_callback_handler(Phase.POST_TRAINING, context)
 
 
             if not self.ddp_silent_mode:
             if not self.ddp_silent_mode:
-                if self.model_checkpoints_location != 'local':
-                    logger.info('[CLEANUP] - Saving Checkpoint files')
-                    self.sg_logger.upload()
-
                 self.sg_logger.close()
                 self.sg_logger.close()
 
 
     def _reset_best_metric(self):
     def _reset_best_metric(self):
@@ -1367,10 +1354,7 @@ class Trainer:
             ckpt_local_path = get_ckpt_local_path(source_ckpt_folder_name=self.source_ckpt_folder_name,
             ckpt_local_path = get_ckpt_local_path(source_ckpt_folder_name=self.source_ckpt_folder_name,
                                                   experiment_name=self.experiment_name,
                                                   experiment_name=self.experiment_name,
                                                   ckpt_name=self.ckpt_name,
                                                   ckpt_name=self.ckpt_name,
-                                                  model_checkpoints_location=self.model_checkpoints_location,
-                                                  external_checkpoint_path=self.external_checkpoint_path,
-                                                  overwrite_local_checkpoint=self.overwrite_local_checkpoint,
-                                                  load_weights_only=self.load_weights_only)
+                                                  external_checkpoint_path=self.external_checkpoint_path)
 
 
             # LOAD CHECKPOINT TO MODEL
             # LOAD CHECKPOINT TO MODEL
             self.checkpoint = load_checkpoint_to_model(ckpt_local_path=ckpt_local_path,
             self.checkpoint = load_checkpoint_to_model(ckpt_local_path=ckpt_local_path,
@@ -1390,7 +1374,7 @@ class Trainer:
         self.best_metric = self.checkpoint['acc'] if 'acc' in self.checkpoint.keys() else -1
         self.best_metric = self.checkpoint['acc'] if 'acc' in self.checkpoint.keys() else -1
         self.start_epoch = self.checkpoint['epoch'] if 'epoch' in self.checkpoint.keys() else 0
         self.start_epoch = self.checkpoint['epoch'] if 'epoch' in self.checkpoint.keys() else 0
 
 
-    def _prep_for_test(self, test_loader: torch.utils.data.DataLoader = None, loss=None, post_prediction_callback=None,
+    def _prep_for_test(self, test_loader: torch.utils.data.DataLoader = None, loss=None,
                        test_metrics_list=None,
                        test_metrics_list=None,
                        loss_logging_items_names=None, test_phase_callbacks=None):
                        loss_logging_items_names=None, test_phase_callbacks=None):
         """Run commands that are common to all models"""
         """Run commands that are common to all models"""
@@ -1400,7 +1384,6 @@ class Trainer:
         # IF SPECIFIED IN THE FUNCTION CALL - OVERRIDE THE self ARGUMENTS
         # IF SPECIFIED IN THE FUNCTION CALL - OVERRIDE THE self ARGUMENTS
         self.test_loader = test_loader or self.test_loader
         self.test_loader = test_loader or self.test_loader
         self.criterion = loss or self.criterion
         self.criterion = loss or self.criterion
-        self.post_prediction_callback = post_prediction_callback or self.post_prediction_callback
         self.loss_logging_items_names = loss_logging_items_names or self.loss_logging_items_names
         self.loss_logging_items_names = loss_logging_items_names or self.loss_logging_items_names
         self.phase_callbacks = test_phase_callbacks or self.phase_callbacks
         self.phase_callbacks = test_phase_callbacks or self.phase_callbacks
 
 
@@ -1445,7 +1428,7 @@ class Trainer:
 
 
         # OVERRIDE SOME PARAMETERS TO MAKE SURE THEY MATCH THE TRAINING PARAMETERS
         # OVERRIDE SOME PARAMETERS TO MAKE SURE THEY MATCH THE TRAINING PARAMETERS
         general_sg_logger_params = {'experiment_name': self.experiment_name,
         general_sg_logger_params = {'experiment_name': self.experiment_name,
-                                    'storage_location': self.model_checkpoints_location,
+                                    'storage_location': 'local',
                                     'resumed': self.load_checkpoint,
                                     'resumed': self.load_checkpoint,
                                     'training_params': self.training_params,
                                     'training_params': self.training_params,
                                     'checkpoints_dir_path': self.checkpoints_dir_path}
                                     'checkpoints_dir_path': self.checkpoints_dir_path}
Discard
@@ -10,8 +10,7 @@ except (ModuleNotFoundError, ImportError, NameError):
     from torch.hub import _download_url_to_file as download_url_to_file
     from torch.hub import _download_url_to_file as download_url_to_file
 
 
 
 
-def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, model_checkpoints_location: str, external_checkpoint_path: str,
-                        overwrite_local_checkpoint: bool, load_weights_only: bool):
+def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, external_checkpoint_path: str):
     """
     """
     Gets the local path to the checkpoint file, which will be:
     Gets the local path to the checkpoint file, which will be:
         - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
         - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
@@ -24,30 +23,11 @@ def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt
     @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.
     @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.
     @param experiment_name: experiment name attr in trainer
     @param experiment_name: experiment name attr in trainer
     @param ckpt_name: checkpoint filename
     @param ckpt_name: checkpoint filename
-    @param model_checkpoints_location: S3, local ot URL
     @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)
     @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)
-    @param overwrite_local_checkpoint: whether to overwrite the checkpoint file with the same name when downloading from S3.
-    @param load_weights_only: whether to load the network's state dict only.
     @return:
     @return:
     """
     """
     source_ckpt_folder_name = source_ckpt_folder_name or experiment_name
     source_ckpt_folder_name = source_ckpt_folder_name or experiment_name
-    if model_checkpoints_location == 'local':
-        ckpt_local_path = external_checkpoint_path or pkg_resources.resource_filename('checkpoints', source_ckpt_folder_name + os.path.sep + ckpt_name)
-
-    # COPY THE DATA FROM 'S3'/'URL' INTO A LOCAL DIRECTORY
-    elif model_checkpoints_location.startswith('s3') or model_checkpoints_location == 'url':
-        # COPY REMOTE DATA TO A LOCAL DIRECTORY AND GET THAT DIRECTORYs NAME
-        ckpt_local_path = copy_ckpt_to_local_folder(local_ckpt_destination_dir=experiment_name,
-                                                    ckpt_filename=ckpt_name,
-                                                    remote_ckpt_source_dir=source_ckpt_folder_name,
-                                                    path_src=model_checkpoints_location,
-                                                    overwrite_local_ckpt=overwrite_local_checkpoint,
-                                                    load_weights_only=load_weights_only)
-
-    else:
-        # ERROR IN USER CODE FLOW - THIS WILL EVENTUALLY RAISE AN EXCEPTION
-        raise NotImplementedError(
-            'model_checkpoints_data_source: ' + str(model_checkpoints_location) + 'not supported')
+    ckpt_local_path = external_checkpoint_path or pkg_resources.resource_filename('checkpoints', source_ckpt_folder_name + os.path.sep + ckpt_name)
 
 
     return ckpt_local_path
     return ckpt_local_path
 
 
Discard
@@ -19,7 +19,6 @@ class ModelWeightAveraging:
                  source_ckpt_folder_name=None, metric_to_watch='acc',
                  source_ckpt_folder_name=None, metric_to_watch='acc',
                  metric_idx=1, load_checkpoint=False,
                  metric_idx=1, load_checkpoint=False,
                  number_of_models_to_average=10,
                  number_of_models_to_average=10,
-                 model_checkpoints_location='local'
                  ):
                  ):
         """
         """
         Init the ModelWeightAveraging
         Init the ModelWeightAveraging
@@ -45,7 +44,6 @@ class ModelWeightAveraging:
                                                                   source_ckpt_folder_name=source_ckpt_folder_name,
                                                                   source_ckpt_folder_name=source_ckpt_folder_name,
                                                                   ckpt_filename="averaging_snapshots.pkl",
                                                                   ckpt_filename="averaging_snapshots.pkl",
                                                                   load_weights_only=False,
                                                                   load_weights_only=False,
-                                                                  model_checkpoints_location=model_checkpoints_location,
                                                                   overwrite_local_ckpt=True)
                                                                   overwrite_local_ckpt=True)
 
 
         else:
         else:
Discard
@@ -16,7 +16,7 @@ from super_gradients.training.dataloaders.dataloaders import (
 class TestCifarTrainer(unittest.TestCase):
 class TestCifarTrainer(unittest.TestCase):
     def test_train_cifar10_dataloader(self):
     def test_train_cifar10_dataloader(self):
         super_gradients.init_trainer()
         super_gradients.init_trainer()
-        trainer = Trainer("test", model_checkpoints_location="local")
+        trainer = Trainer("test")
         cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
         cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
         model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
         model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
         trainer.train(
         trainer.train(
@@ -35,7 +35,7 @@ class TestCifarTrainer(unittest.TestCase):
 
 
     def test_train_cifar100_dataloader(self):
     def test_train_cifar100_dataloader(self):
         super_gradients.init_trainer()
         super_gradients.init_trainer()
-        trainer = Trainer("test", model_checkpoints_location="local")
+        trainer = Trainer("test")
         cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
         cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
         model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
         model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
         trainer.train(
         trainer.train(
Discard
@@ -38,7 +38,7 @@ class TestTrainer(unittest.TestCase):
 
 
     @staticmethod
     @staticmethod
     def get_classification_trainer(name=''):
     def get_classification_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local')
+        trainer = Trainer(name)
         model = models.get("resnet18", num_classes=5)
         model = models.get("resnet18", num_classes=5)
         return trainer, model
         return trainer, model
 
 
Discard
@@ -69,7 +69,7 @@ class ConversionCallbackTest(unittest.TestCase):
                 "phase_callbacks": phase_callbacks,
                 "phase_callbacks": phase_callbacks,
             }
             }
 
 
-            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local",
+            trainer = Trainer(f"{architecture}_example",
                               ckpt_root_dir=checkpoint_dir)
                               ckpt_root_dir=checkpoint_dir)
             model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
             model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
             try:
             try:
@@ -102,7 +102,7 @@ class ConversionCallbackTest(unittest.TestCase):
 
 
         for architecture in SEMANTIC_SEGMENTATION:
         for architecture in SEMANTIC_SEGMENTATION:
             model_meta_data = generate_model_metadata(architecture=architecture, task=Task.SEMANTIC_SEGMENTATION)
             model_meta_data = generate_model_metadata(architecture=architecture, task=Task.SEMANTIC_SEGMENTATION)
-            trainer = Trainer(f"{architecture}_example", model_checkpoints_location="local",
+            trainer = Trainer(f"{architecture}_example",
                               ckpt_root_dir=checkpoint_dir)
                               ckpt_root_dir=checkpoint_dir)
             model = models.get(model_name=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
             model = models.get(model_name=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
 
 
Discard
@@ -10,7 +10,7 @@ from deci_lab_client.models import Metric, QuantizationLevel, ModelMetadata, Opt
 
 
 class DeciLabUploadTest(unittest.TestCase):
 class DeciLabUploadTest(unittest.TestCase):
     def setUp(self) -> None:
     def setUp(self) -> None:
-        self.trainer = Trainer("deci_lab_export_test_model", model_checkpoints_location='local')
+        self.trainer = Trainer("deci_lab_export_test_model")
 
 
     def test_train_with_deci_lab_integration(self):
     def test_train_with_deci_lab_integration(self):
         model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
         model_meta_data = ModelMetadata(name='model_for_deci_lab_upload_test',
Discard
@@ -23,7 +23,7 @@ class CallWrapper:
 class EMAIntegrationTest(unittest.TestCase):
 class EMAIntegrationTest(unittest.TestCase):
 
 
     def _init_model(self) -> None:
     def _init_model(self) -> None:
-        self.trainer = Trainer("resnet18_cifar_ema_test", model_checkpoints_location='local',
+        self.trainer = Trainer("resnet18_cifar_ema_test",
                                device='cpu', multi_gpu=MultiGPUMode.OFF)
                                device='cpu', multi_gpu=MultiGPUMode.OFF)
         self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
         self.model = models.get("resnet18_cifar", arch_params={"num_classes": 5})
 
 
Discard
@@ -30,7 +30,7 @@ class LRTest(unittest.TestCase):
 
 
     @staticmethod
     @staticmethod
     def get_trainer(name=''):
     def get_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local')
+        trainer = Trainer(name)
         model = models.get("resnet18_cifar", num_classes=5)
         model = models.get("resnet18_cifar", num_classes=5)
         return trainer, model
         return trainer, model
 
 
Discard
@@ -235,7 +235,7 @@ class PretrainedModelsTest(unittest.TestCase):
         }
         }
 
 
     def test_pretrained_resnet50_imagenet(self):
     def test_pretrained_resnet50_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet50', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                            **self.imagenet_pretrained_ckpt_params)
                            **self.imagenet_pretrained_ckpt_params)
@@ -244,7 +244,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet50"], delta=0.001)
 
 
     def test_transfer_learning_resnet50_imagenet(self):
     def test_transfer_learning_resnet50_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet50_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
         model = models.get("resnet50", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -253,7 +253,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_resnet34_imagenet(self):
     def test_pretrained_resnet34_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet34', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet34',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -263,7 +263,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet34"], delta=0.001)
 
 
     def test_transfer_learning_resnet34_imagenet(self):
     def test_transfer_learning_resnet34_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet34_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet34_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
         model = models.get("resnet34", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -272,7 +272,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_resnet18_imagenet(self):
     def test_pretrained_resnet18_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet18', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet18',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
@@ -282,7 +282,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["resnet18"], delta=0.001)
 
 
     def test_transfer_learning_resnet18_imagenet(self):
     def test_transfer_learning_resnet18_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet18_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet18_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
         model = models.get("resnet18", arch_params=self.imagenet_pretrained_arch_params["resnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -291,7 +291,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY800', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -301,7 +301,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY800"], delta=0.001)
 
 
     def test_transfer_learning_regnetY800_imagenet(self):
     def test_transfer_learning_regnetY800_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY800_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
         model = models.get("regnetY800", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -310,7 +310,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_regnetY600_imagenet(self):
     def test_pretrained_regnetY600_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY600', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY600',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -320,7 +320,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY600"], delta=0.001)
 
 
     def test_transfer_learning_regnetY600_imagenet(self):
     def test_transfer_learning_regnetY600_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY600_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY600_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
         model = models.get("regnetY600", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -329,7 +329,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_regnetY400_imagenet(self):
     def test_pretrained_regnetY400_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY400', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY400',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -339,7 +339,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY400"], delta=0.001)
 
 
     def test_transfer_learning_regnetY400_imagenet(self):
     def test_transfer_learning_regnetY400_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY400_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY400_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
         model = models.get("regnetY400", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -348,7 +348,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_regnetY200_imagenet(self):
     def test_pretrained_regnetY200_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY200', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY200',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
@@ -358,7 +358,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["regnetY200"], delta=0.001)
 
 
     def test_transfer_learning_regnetY200_imagenet(self):
     def test_transfer_learning_regnetY200_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY200_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY200_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
         model = models.get("regnetY200", arch_params=self.imagenet_pretrained_arch_params["regnet"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -367,7 +367,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_repvgg_a0', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
@@ -377,7 +377,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["repvgg_a0"], delta=0.001)
 
 
     def test_transfer_learning_repvgg_a0_imagenet(self):
     def test_transfer_learning_repvgg_a0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_repvgg_a0_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
         model = models.get("repvgg_a0", arch_params=self.imagenet_pretrained_arch_params["repvgg_a0"],
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
                            **self.imagenet_pretrained_ckpt_params, num_classes=5)
@@ -386,7 +386,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_regseg48_cityscapes(self):
     def test_pretrained_regseg48_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_regseg48', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_regseg48',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -396,7 +396,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["regseg48"], delta=0.001)
 
 
     def test_transfer_learning_regseg48_cityscapes(self):
     def test_transfer_learning_regseg48_cityscapes(self):
-        trainer = Trainer('regseg48_cityscapes_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('regseg48_cityscapes_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
         model = models.get("regseg48", arch_params=self.cityscapes_pretrained_arch_params["regseg48"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -405,7 +405,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       training_params=self.regseg_transfer_segmentation_train_params)
                       training_params=self.regseg_transfer_segmentation_train_params)
 
 
     def test_pretrained_ddrnet23_cityscapes(self):
     def test_pretrained_ddrnet23_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_ddrnet23', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -415,7 +415,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23"], delta=0.001)
 
 
     def test_pretrained_ddrnet23_slim_cityscapes(self):
     def test_pretrained_ddrnet23_slim_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -425,7 +425,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["ddrnet_23_slim"], delta=0.001)
 
 
     def test_transfer_learning_ddrnet23_cityscapes(self):
     def test_transfer_learning_ddrnet23_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_ddrnet23_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
         model = models.get("ddrnet_23", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -434,7 +434,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
                       valid_loader=self.transfer_segmentation_dataset)
 
 
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
     def test_transfer_learning_ddrnet23_slim_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_ddrnet23_slim_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
         model = models.get("ddrnet_23_slim", arch_params=self.cityscapes_pretrained_arch_params["ddrnet_23"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -443,7 +443,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
                       valid_loader=self.transfer_segmentation_dataset)
 
 
     def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
     def test_pretrained_coco_segmentation_subclass_pretrained_shelfnet34_lw(self):
-        trainer = Trainer('coco_segmentation_subclass_pretrained_shelfnet34_lw', model_checkpoints_location='local',
+        trainer = Trainer('coco_segmentation_subclass_pretrained_shelfnet34_lw',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("shelfnet34_lw",
         model = models.get("shelfnet34_lw",
                            arch_params=self.coco_segmentation_subclass_pretrained_arch_params["shelfnet34_lw"],
                            arch_params=self.coco_segmentation_subclass_pretrained_arch_params["shelfnet34_lw"],
@@ -453,7 +453,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_segmentation_subclass_pretrained_mious["shelfnet34_lw"], delta=0.001)
 
 
     def test_pretrained_efficientnet_b0_imagenet(self):
     def test_pretrained_efficientnet_b0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_efficientnet_b0', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_efficientnet_b0',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
         model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
@@ -463,7 +463,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["efficientnet_b0"], delta=0.001)
 
 
     def test_transfer_learning_efficientnet_b0_imagenet(self):
     def test_transfer_learning_efficientnet_b0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_efficientnet_b0_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_efficientnet_b0_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
         model = models.get("efficientnet_b0", arch_params=self.imagenet_pretrained_arch_params["efficientnet_b0"],
@@ -473,7 +473,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
     def test_pretrained_ssd_lite_mobilenet_v2_coco(self):
-        trainer = Trainer('coco_ssd_lite_mobilenet_v2', model_checkpoints_location='local',
+        trainer = Trainer('coco_ssd_lite_mobilenet_v2',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ssd_lite_mobilenet_v2",
         model = models.get("ssd_lite_mobilenet_v2",
                            arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"],
                            arch_params=self.coco_pretrained_arch_params["ssd_lite_mobilenet_v2"],
@@ -485,7 +485,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
     def test_transfer_learning_ssd_lite_mobilenet_v2_coco(self):
         trainer = Trainer('coco_ssd_lite_mobilenet_v2_transfer_learning',
         trainer = Trainer('coco_ssd_lite_mobilenet_v2_transfer_learning',
-                          model_checkpoints_location='local', multi_gpu=MultiGPUMode.OFF)
+                          multi_gpu=MultiGPUMode.OFF)
         transfer_arch_params = self.coco_pretrained_arch_params['ssd_lite_mobilenet_v2'].copy()
         transfer_arch_params = self.coco_pretrained_arch_params['ssd_lite_mobilenet_v2'].copy()
         transfer_arch_params['num_classes'] = 5
         transfer_arch_params['num_classes'] = 5
         model = models.get("ssd_lite_mobilenet_v2",
         model = models.get("ssd_lite_mobilenet_v2",
@@ -496,11 +496,10 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_detection_dataset)
                       valid_loader=self.transfer_detection_dataset)
 
 
     def test_pretrained_ssd_mobilenet_v1_coco(self):
     def test_pretrained_ssd_mobilenet_v1_coco(self):
-        trainer = Trainer('coco_ssd_mobilenet_v1', model_checkpoints_location='local',
+        trainer = Trainer('coco_ssd_mobilenet_v1',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("ssd_mobilenet_v1",
         model = models.get("ssd_mobilenet_v1",
-                           arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"],
-                           **self.coco_pretrained_ckpt_params)
+                           arch_params=self.coco_pretrained_arch_params["coco_ssd_mobilenet_v1"], **self.coco_pretrained_ckpt_params)
         ssd_post_prediction_callback = SSDPostPredictCallback()
         ssd_post_prediction_callback = SSDPostPredictCallback()
         res = trainer.test(model=model, test_loader=self.coco_dataset['ssd_mobilenet'],
         res = trainer.test(model=model, test_loader=self.coco_dataset['ssd_mobilenet'],
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback,
                            test_metrics_list=[DetectionMetrics(post_prediction_callback=ssd_post_prediction_callback,
@@ -509,7 +508,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["coco_ssd_mobilenet_v1"], delta=0.001)
 
 
     def test_pretrained_yolox_s_coco(self):
     def test_pretrained_yolox_s_coco(self):
-        trainer = Trainer('yolox_s', model_checkpoints_location='local',
+        trainer = Trainer('yolox_s',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("yolox_s",
         model = models.get("yolox_s",
@@ -521,7 +520,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_s"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_s"], delta=0.001)
 
 
     def test_pretrained_yolox_m_coco(self):
     def test_pretrained_yolox_m_coco(self):
-        trainer = Trainer('yolox_m', model_checkpoints_location='local',
+        trainer = Trainer('yolox_m',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_m",
         model = models.get("yolox_m",
                            **self.coco_pretrained_ckpt_params)
                            **self.coco_pretrained_ckpt_params)
@@ -532,7 +531,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_m"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_m"], delta=0.001)
 
 
     def test_pretrained_yolox_l_coco(self):
     def test_pretrained_yolox_l_coco(self):
-        trainer = Trainer('yolox_l', model_checkpoints_location='local',
+        trainer = Trainer('yolox_l',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_l",
         model = models.get("yolox_l",
                            **self.coco_pretrained_ckpt_params)
                            **self.coco_pretrained_ckpt_params)
@@ -543,7 +542,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_l"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_l"], delta=0.001)
 
 
     def test_pretrained_yolox_n_coco(self):
     def test_pretrained_yolox_n_coco(self):
-        trainer = Trainer('yolox_n', model_checkpoints_location='local',
+        trainer = Trainer('yolox_n',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("yolox_n",
         model = models.get("yolox_n",
@@ -555,7 +554,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_n"], delta=0.001)
         self.assertAlmostEqual(res, self.coco_pretrained_maps["yolox_n"], delta=0.001)
 
 
     def test_pretrained_yolox_t_coco(self):
     def test_pretrained_yolox_t_coco(self):
-        trainer = Trainer('yolox_t', model_checkpoints_location='local',
+        trainer = Trainer('yolox_t',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_t",
         model = models.get("yolox_t",
                            **self.coco_pretrained_ckpt_params)
                            **self.coco_pretrained_ckpt_params)
@@ -567,7 +566,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_yolox_n_coco(self):
     def test_transfer_learning_yolox_n_coco(self):
         trainer = Trainer('test_transfer_learning_yolox_n_coco',
         trainer = Trainer('test_transfer_learning_yolox_n_coco',
-                          model_checkpoints_location='local',
+
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_n", **self.coco_pretrained_ckpt_params, num_classes=5)
         model = models.get("yolox_n", **self.coco_pretrained_ckpt_params, num_classes=5)
         trainer.train(model=model, training_params=self.transfer_detection_train_params_yolox,
         trainer.train(model=model, training_params=self.transfer_detection_train_params_yolox,
@@ -576,7 +575,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_mobilenet_v3_large_imagenet(self):
     def test_transfer_learning_mobilenet_v3_large_imagenet(self):
         trainer = Trainer('imagenet_pretrained_mobilenet_v3_large_transfer_learning',
         trainer = Trainer('imagenet_pretrained_mobilenet_v3_large_transfer_learning',
-                          model_checkpoints_location='local',
+
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
         model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -586,7 +585,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_mobilenet_v3_large_imagenet(self):
     def test_pretrained_mobilenet_v3_large_imagenet(self):
-        trainer = Trainer('imagenet_mobilenet_v3_large', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_mobilenet_v3_large',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
         model = models.get("mobilenet_v3_large", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -597,7 +596,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
     def test_transfer_learning_mobilenet_v3_small_imagenet(self):
         trainer = Trainer('imagenet_pretrained_mobilenet_v3_small_transfer_learning',
         trainer = Trainer('imagenet_pretrained_mobilenet_v3_small_transfer_learning',
-                          model_checkpoints_location='local',
+
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
         model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -607,7 +606,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_mobilenet_v3_small_imagenet(self):
     def test_pretrained_mobilenet_v3_small_imagenet(self):
-        trainer = Trainer('imagenet_mobilenet_v3_small', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_mobilenet_v3_small',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
         model = models.get("mobilenet_v3_small", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -618,7 +617,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_mobilenet_v2_imagenet(self):
     def test_transfer_learning_mobilenet_v2_imagenet(self):
         trainer = Trainer('imagenet_pretrained_mobilenet_v2_transfer_learning',
         trainer = Trainer('imagenet_pretrained_mobilenet_v2_transfer_learning',
-                          model_checkpoints_location='local',
+
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -628,7 +627,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_mobilenet_v2_imagenet(self):
     def test_pretrained_mobilenet_v2_imagenet(self):
-        trainer = Trainer('imagenet_mobilenet_v2', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_mobilenet_v2',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
         model = models.get("mobilenet_v2", arch_params=self.imagenet_pretrained_arch_params["mobilenet"],
@@ -638,7 +637,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["mobilenet_v2"], delta=0.001)
 
 
     def test_pretrained_stdc1_seg50_cityscapes(self):
     def test_pretrained_stdc1_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc1_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg50',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -648,7 +647,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg50"], delta=0.001)
 
 
     def test_transfer_learning_stdc1_seg50_cityscapes(self):
     def test_transfer_learning_stdc1_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc1_seg50_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg50_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
         model = models.get("stdc1_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
@@ -657,7 +656,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
                       valid_loader=self.transfer_segmentation_dataset)
 
 
     def test_pretrained_stdc1_seg75_cityscapes(self):
     def test_pretrained_stdc1_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc1_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg75',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -667,7 +666,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc1_seg75"], delta=0.001)
 
 
     def test_transfer_learning_stdc1_seg75_cityscapes(self):
     def test_transfer_learning_stdc1_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc1_seg75_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc1_seg75_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
         model = models.get("stdc1_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
@@ -676,7 +675,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
                       valid_loader=self.transfer_segmentation_dataset)
 
 
     def test_pretrained_stdc2_seg50_cityscapes(self):
     def test_pretrained_stdc2_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc2_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg50',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -686,7 +685,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg50"], delta=0.001)
 
 
     def test_transfer_learning_stdc2_seg50_cityscapes(self):
     def test_transfer_learning_stdc2_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc2_seg50_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg50_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
         model = models.get("stdc2_seg50", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
@@ -695,7 +694,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_segmentation_dataset)
                       valid_loader=self.transfer_segmentation_dataset)
 
 
     def test_pretrained_stdc2_seg75_cityscapes(self):
     def test_pretrained_stdc2_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc2_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg75',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -705,7 +704,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["stdc2_seg75"], delta=0.001)
 
 
     def test_transfer_learning_stdc2_seg75_cityscapes(self):
     def test_transfer_learning_stdc2_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_stdc2_seg75_transfer_learning', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_stdc2_seg75_transfer_learning',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
         model = models.get("stdc2_seg75", arch_params=self.cityscapes_pretrained_arch_params["stdc"],
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
                            **self.cityscapes_pretrained_ckpt_params, num_classes=5)
@@ -715,7 +714,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_vit_base_imagenet21k(self):
     def test_transfer_learning_vit_base_imagenet21k(self):
         trainer = Trainer('imagenet21k_pretrained_vit_base',
         trainer = Trainer('imagenet21k_pretrained_vit_base',
-                          model_checkpoints_location='local',
+
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
@@ -726,7 +725,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_vit_large_imagenet21k(self):
     def test_transfer_learning_vit_large_imagenet21k(self):
         trainer = Trainer('imagenet21k_pretrained_vit_large',
         trainer = Trainer('imagenet21k_pretrained_vit_large',
-                          model_checkpoints_location='local',
+
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
@@ -736,7 +735,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_vit_base_imagenet(self):
     def test_pretrained_vit_base_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_vit_base', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_vit_base',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
         model = models.get("vit_base", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                            **self.imagenet_pretrained_ckpt_params)
                            **self.imagenet_pretrained_ckpt_params)
@@ -746,7 +745,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_base"], delta=0.001)
 
 
     def test_pretrained_vit_large_imagenet(self):
     def test_pretrained_vit_large_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_vit_large', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_vit_large',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
         model = models.get("vit_large", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                            **self.imagenet_pretrained_ckpt_params)
                            **self.imagenet_pretrained_ckpt_params)
@@ -756,7 +755,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)
         self.assertAlmostEqual(res, self.imagenet_pretrained_accuracies["vit_large"], delta=0.001)
 
 
     def test_pretrained_beit_base_imagenet(self):
     def test_pretrained_beit_base_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_beit_base', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_beit_base',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
         model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
                            **self.imagenet_pretrained_ckpt_params)
                            **self.imagenet_pretrained_ckpt_params)
@@ -767,7 +766,7 @@ class PretrainedModelsTest(unittest.TestCase):
 
 
     def test_transfer_learning_beit_base_imagenet(self):
     def test_transfer_learning_beit_base_imagenet(self):
         trainer = Trainer('test_transfer_learning_beit_base_imagenet',
         trainer = Trainer('test_transfer_learning_beit_base_imagenet',
-                          model_checkpoints_location='local',
+
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
 
 
         model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
         model = models.get("beit_base_patch16_224", arch_params=self.imagenet_pretrained_arch_params["vit_base"],
@@ -777,7 +776,7 @@ class PretrainedModelsTest(unittest.TestCase):
                       valid_loader=self.transfer_classification_dataloader)
                       valid_loader=self.transfer_classification_dataloader)
 
 
     def test_pretrained_pplite_t_seg50_cityscapes(self):
     def test_pretrained_pplite_t_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_pplite_t_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_pplite_t_seg50',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("pp_lite_t_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
         model = models.get("pp_lite_t_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -789,7 +788,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg50"], delta=0.001)
 
 
     def test_pretrained_pplite_t_seg75_cityscapes(self):
     def test_pretrained_pplite_t_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_pplite_t_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_pplite_t_seg75',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("pp_lite_t_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
         model = models.get("pp_lite_t_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -801,7 +800,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg75"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_t_seg75"], delta=0.001)
 
 
     def test_pretrained_pplite_b_seg50_cityscapes(self):
     def test_pretrained_pplite_b_seg50_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_pplite_b_seg50', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_pplite_b_seg50',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("pp_lite_b_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
         model = models.get("pp_lite_b_seg50", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
@@ -813,7 +812,7 @@ class PretrainedModelsTest(unittest.TestCase):
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg50"], delta=0.001)
         self.assertAlmostEqual(res, self.cityscapes_pretrained_mious["pp_lite_b_seg50"], delta=0.001)
 
 
     def test_pretrained_pplite_b_seg75_cityscapes(self):
     def test_pretrained_pplite_b_seg75_cityscapes(self):
-        trainer = Trainer('cityscapes_pretrained_pplite_b_seg75', model_checkpoints_location='local',
+        trainer = Trainer('cityscapes_pretrained_pplite_b_seg75',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("pp_lite_b_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
         model = models.get("pp_lite_b_seg75", arch_params=self.cityscapes_pretrained_arch_params["pplite_seg"],
                            **self.cityscapes_pretrained_ckpt_params)
                            **self.cityscapes_pretrained_ckpt_params)
Discard
@@ -10,7 +10,7 @@ from super_gradients.training.utils.quantization_utils import PostQATConversionC
 class QATIntegrationTest(unittest.TestCase):
 class QATIntegrationTest(unittest.TestCase):
     def _get_trainer(self, experiment_name):
     def _get_trainer(self, experiment_name):
         trainer = Trainer(experiment_name,
         trainer = Trainer(experiment_name,
-                          model_checkpoints_location='local',
+
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet18", pretrained_weights="imagenet")
         model = models.get("resnet18", pretrained_weights="imagenet")
         return trainer, model
         return trainer, model
Discard
@@ -18,9 +18,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
         """
         """
         # Create dataset
         # Create dataset
 
 
-        trainer = Trainer('dataset_statistics_visual_test',
-                          model_checkpoints_location='local',
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer('dataset_statistics_visual_test')
 
 
         model = models.get("yolox_s")
         model = models.get("yolox_s")
 
 
Discard
@@ -12,10 +12,9 @@ class TestDetectionUtils(unittest.TestCase):
     def test_visualization(self):
     def test_visualization(self):
 
 
         # Create Yolo model
         # Create Yolo model
-        trainer = Trainer('visualization_test',
-                          model_checkpoints_location='local',
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer('visualization_test')
         model = models.get("yolox_n", pretrained_weights="coco")
         model = models.get("yolox_n", pretrained_weights="coco")
+        post_prediction_callback = YoloPostPredictionCallback()
 
 
         # Simulate one iteration of validation subset
         # Simulate one iteration of validation subset
         valid_loader = coco2017_val()
         valid_loader = coco2017_val()
@@ -23,7 +22,7 @@ class TestDetectionUtils(unittest.TestCase):
         imgs = core_utils.tensor_container_to_device(imgs, trainer.device)
         imgs = core_utils.tensor_container_to_device(imgs, trainer.device)
         targets = core_utils.tensor_container_to_device(targets, trainer.device)
         targets = core_utils.tensor_container_to_device(targets, trainer.device)
         output = model(imgs)
         output = model(imgs)
-        output = trainer.post_prediction_callback(output)
+        output = post_prediction_callback(output)
         # Visualize the batch
         # Visualize the batch
         DetectionVisualization.visualize_batch(imgs, output, targets, batch_i,
         DetectionVisualization.visualize_batch(imgs, output, targets, batch_i,
                                                COCO_DETECTION_CLASSES_LIST, trainer.checkpoints_dir_path)
                                                COCO_DETECTION_CLASSES_LIST, trainer.checkpoints_dir_path)
Discard
@@ -58,7 +58,7 @@ class EarlyStopTest(unittest.TestCase):
         Test for mode=min metric, test that training stops after no improvement in metric value for amount of `patience`
         Test for mode=min metric, test that training stops after no improvement in metric value for amount of `patience`
         epochs.
         epochs.
         """
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", patience=3, verbose=True)
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", patience=3, verbose=True)
         phase_callbacks = [early_stop_loss]
         phase_callbacks = [early_stop_loss]
@@ -80,7 +80,7 @@ class EarlyStopTest(unittest.TestCase):
         Test for mode=max metric, test that training stops after no improvement in metric value for amount of `patience`
         Test for mode=max metric, test that training stops after no improvement in metric value for amount of `patience`
         epochs.
         epochs.
         """
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=3,
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=3,
                                    verbose=True)
                                    verbose=True)
         phase_callbacks = [early_stop_acc]
         phase_callbacks = [early_stop_acc]
@@ -101,7 +101,7 @@ class EarlyStopTest(unittest.TestCase):
         """
         """
         Test for mode=min metric, test that training stops after metric value reaches the `threshold` value.
         Test for mode=min metric, test that training stops after metric value reaches the `threshold` value.
         """
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", threshold=0.1, verbose=True)
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", threshold=0.1, verbose=True)
         phase_callbacks = [early_stop_loss]
         phase_callbacks = [early_stop_loss]
@@ -121,7 +121,7 @@ class EarlyStopTest(unittest.TestCase):
         """
         """
         Test for mode=max metric, test that training stops after metric value reaches the `threshold` value.
         Test for mode=max metric, test that training stops after metric value reaches the `threshold` value.
         """
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
 
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", threshold=0.94,
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", threshold=0.94,
                                    verbose=True)
                                    verbose=True)
@@ -144,7 +144,7 @@ class EarlyStopTest(unittest.TestCase):
         Test that training stops when monitor value is not a finite number. Test case of NaN and Inf values.
         Test that training stops when monitor value is not a finite number. Test case of NaN and Inf values.
         """
         """
         # test Nan value
         # test Nan value
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", check_finite=True,
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", check_finite=True,
                                     verbose=True)
                                     verbose=True)
@@ -162,7 +162,7 @@ class EarlyStopTest(unittest.TestCase):
         self.assertEqual(excepted_end_epoch, fake_loss.count // 2)
         self.assertEqual(excepted_end_epoch, fake_loss.count // 2)
 
 
         # test Inf value
         # test Inf value
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
 
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", patience=3, verbose=True)
         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", patience=3, verbose=True)
         phase_callbacks = [early_stop_loss]
         phase_callbacks = [early_stop_loss]
@@ -183,7 +183,7 @@ class EarlyStopTest(unittest.TestCase):
         Test for `min_delta` argument, metric value is considered an improvement only if
         Test for `min_delta` argument, metric value is considered an improvement only if
         current_value - min_delta > best_value
         current_value - min_delta > best_value
         """
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
 
 
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=2,
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=2,
                                    min_delta=0.1, verbose=True)
                                    min_delta=0.1, verbose=True)
Discard
@@ -11,7 +11,7 @@ from super_gradients.training.metrics import Accuracy, Top5
 class FactoriesTest(unittest.TestCase):
 class FactoriesTest(unittest.TestCase):
 
 
     def test_training_with_factories(self):
     def test_training_with_factories(self):
-        trainer = Trainer("test_train_with_factories", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_factories")
         net = models.get("resnet18", num_classes=5)
         net = models.get("resnet18", num_classes=5)
         train_params = {"max_epochs": 2,
         train_params = {"max_epochs": 2,
                         "lr_updates": [1],
                         "lr_updates": [1],
Discard
@@ -6,7 +6,6 @@ from super_gradients import Trainer
 import torch
 import torch
 from torch.utils.data import TensorDataset, DataLoader
 from torch.utils.data import TensorDataset, DataLoader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
-from super_gradients.training.exceptions.sg_trainer_exceptions import IllegalDataloaderInitialization
 
 
 
 
 class InitializeWithDataloadersTest(unittest.TestCase):
 class InitializeWithDataloadersTest(unittest.TestCase):
@@ -26,27 +25,8 @@ class InitializeWithDataloadersTest(unittest.TestCase):
         label = torch.randint(0, len(self.testcase_classes), size=(test_size,))
         label = torch.randint(0, len(self.testcase_classes), size=(test_size,))
         self.testcase_testloader = DataLoader(TensorDataset(inp, label))
         self.testcase_testloader = DataLoader(TensorDataset(inp, label))
 
 
-    def test_initialization_rules(self):
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader, valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader, classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
-                valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
-                valid_loader=self.testcase_validloader, test_loader=self.testcase_testloader,
-                classes=self.testcase_classes)
-
     def test_train_with_dataloaders(self):
     def test_train_with_dataloaders(self):
-        trainer = Trainer(experiment_name="test_name", model_checkpoints_location="local")
+        trainer = Trainer(experiment_name="test_name")
         model = models.get("resnet18", num_classes=5)
         model = models.get("resnet18", num_classes=5)
         trainer.train(model=model,
         trainer.train(model=model,
                       training_params={"max_epochs": 2,
                       training_params={"max_epochs": 2,
Discard
@@ -29,7 +29,7 @@ class LoadCheckpointWithEmaTest(unittest.TestCase):
     def test_ema_ckpt_reload(self):
     def test_ema_ckpt_reload(self):
         # Define Model
         # Define Model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("ema_ckpt_test")
         trainer.train(model=net, training_params=self.train_params,
         trainer.train(model=net, training_params=self.train_params,
                       train_loader=classification_test_dataloader(),
                       train_loader=classification_test_dataloader(),
                       valid_loader=classification_test_dataloader())
                       valid_loader=classification_test_dataloader())
@@ -38,7 +38,7 @@ class LoadCheckpointWithEmaTest(unittest.TestCase):
 
 
         # TRAIN FOR 1 MORE EPOCH AND COMPARE THE NET AT THE BEGINNING OF EPOCH 3 AND THE END OF EPOCH NUMBER 2
         # TRAIN FOR 1 MORE EPOCH AND COMPARE THE NET AT THE BEGINNING OF EPOCH 3 AND THE END OF EPOCH NUMBER 2
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("ema_ckpt_test")
 
 
         net_collector = PreTrainingEMANetCollector()
         net_collector = PreTrainingEMANetCollector()
         self.train_params["resume"] = True
         self.train_params["resume"] = True
Discard
@@ -10,7 +10,7 @@ class LRCooldownTest(unittest.TestCase):
     def test_lr_cooldown_with_lr_scheduling(self):
     def test_lr_cooldown_with_lr_scheduling(self):
         # Define Model
         # Define Model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
Discard
@@ -38,7 +38,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_lr_warmup(self):
     def test_lr_warmup(self):
         # Define Model
         # Define Model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -60,7 +60,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_lr_warmup_with_lr_scheduling(self):
     def test_lr_warmup_with_lr_scheduling(self):
         # Define model
         # Define model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -85,7 +85,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_warmup_initial_lr(self):
     def test_warmup_initial_lr(self):
         # Define model
         # Define model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("test_warmup_initial_lr", model_checkpoints_location='local')
+        trainer = Trainer("test_warmup_initial_lr")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -107,7 +107,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_custom_lr_warmup(self):
     def test_custom_lr_warmup(self):
         # Define model
         # Define model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("custom_lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("custom_lr_warmup_test")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
Discard
@@ -12,7 +12,7 @@ from torchmetrics import MetricCollection
 
 
 class PhaseContextTest(unittest.TestCase):
 class PhaseContextTest(unittest.TestCase):
     def context_information_in_train_test(self):
     def context_information_in_train_test(self):
-        trainer = Trainer("context_information_in_train_test", model_checkpoints_location='local')
+        trainer = Trainer("context_information_in_train_test")
 
 
         net = ResNet18(num_classes=5, arch_params={})
         net = ResNet18(num_classes=5, arch_params={})
 
 
Discard
@@ -31,7 +31,7 @@ class ContextMethodsCheckerCallback(PhaseCallback):
 class ContextMethodsTest(unittest.TestCase):
 class ContextMethodsTest(unittest.TestCase):
     def test_access_to_methods_by_phase(self):
     def test_access_to_methods_by_phase(self):
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("test_access_to_methods_by_phase", model_checkpoints_location='local')
+        trainer = Trainer("test_access_to_methods_by_phase")
 
 
         phase_callbacks = []
         phase_callbacks = []
         for phase in Phase:
         for phase in Phase:
Discard
@@ -14,21 +14,21 @@ class PretrainedModelsUnitTest(unittest.TestCase):
         self.imagenet_pretrained_models = ["resnet50", "repvgg_a0", "regnetY800"]
         self.imagenet_pretrained_models = ["resnet50", "repvgg_a0", "regnetY800"]
 
 
     def test_pretrained_resnet50_imagenet(self):
     def test_pretrained_resnet50_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet50_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50_unit_test',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet50", pretrained_weights="imagenet")
         model = models.get("resnet50", pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
                      metrics_progress_verbose=True)
                      metrics_progress_verbose=True)
 
 
     def test_pretrained_regnetY800_imagenet(self):
     def test_pretrained_regnetY800_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY800_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800_unit_test',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY800", pretrained_weights="imagenet")
         model = models.get("regnetY800", pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
                      metrics_progress_verbose=True)
                      metrics_progress_verbose=True)
 
 
     def test_pretrained_repvgg_a0_imagenet(self):
     def test_pretrained_repvgg_a0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_repvgg_a0_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0_unit_test',
                           multi_gpu=MultiGPUMode.OFF)
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
Discard
@@ -18,7 +18,7 @@ class SaveCkptListUnitTest(unittest.TestCase):
                         "greater_metric_to_watch_is_better": True}
                         "greater_metric_to_watch_is_better": True}
 
 
         # Define Model
         # Define Model
-        trainer = Trainer("save_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("save_ckpt_test")
 
 
         # Build Model
         # Build Model
         model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
         model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
Discard
@@ -53,7 +53,7 @@ class StrictLoadEnumTest(unittest.TestCase):
         torch.save(cls.change_state_dict_keys(cls.original_torch_model.state_dict()), cls.checkpoint_diff_keys_path)
         torch.save(cls.change_state_dict_keys(cls.original_torch_model.state_dict()), cls.checkpoint_diff_keys_path)
 
 
         # Save the model's state_dict checkpoint in Trainer format
         # Save the model's state_dict checkpoint in Trainer format
-        cls.trainer = Trainer("load_checkpoint_test", model_checkpoints_location='local')  # Saves in /checkpoints
+        cls.trainer = Trainer("load_checkpoint_test")  # Saves in /checkpoints
         cls.trainer.set_net(cls.original_torch_model)
         cls.trainer.set_net(cls.original_torch_model)
         # FIXME: after uniting init and build_model we should remove this
         # FIXME: after uniting init and build_model we should remove this
         cls.trainer.sg_logger = BaseSGLogger('project_name', 'load_checkpoint_test', 'local', resumed=False,
         cls.trainer.sg_logger = BaseSGLogger('project_name', 'load_checkpoint_test', 'local', resumed=False,
Discard
@@ -6,9 +6,9 @@ from super_gradients.training.dataloaders.dataloaders import classification_test
     detection_test_dataloader, segmentation_test_dataloader
     detection_test_dataloader, segmentation_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training import MultiGPUMode, models
 from super_gradients.training import MultiGPUMode, models
-from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
+from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 
 
 
 
 class TestWithoutTrainTest(unittest.TestCase):
 class TestWithoutTrainTest(unittest.TestCase):
@@ -26,22 +26,21 @@ class TestWithoutTrainTest(unittest.TestCase):
 
 
     @staticmethod
     @staticmethod
     def get_classification_trainer(name=''):
     def get_classification_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local')
+        trainer = Trainer(name)
         model = models.get("resnet18", num_classes=5)
         model = models.get("resnet18", num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_detection_trainer(name=''):
     def get_detection_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local',
-                          multi_gpu=MultiGPUMode.OFF,
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer(name,
+                          multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_s", num_classes=5)
         model = models.get("yolox_s", num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_segmentation_trainer(name=''):
     def get_segmentation_trainer(name=''):
         shelfnet_lw_arch_params = {"num_classes": 5, "load_checkpoint": False}
         shelfnet_lw_arch_params = {"num_classes": 5, "load_checkpoint": False}
-        trainer = Trainer(name, model_checkpoints_location='local', multi_gpu=False)
+        trainer = Trainer(name)
         model = models.get('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
         model = models.get('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
         return trainer, model
         return trainer, model
 
 
@@ -52,7 +51,7 @@ class TestWithoutTrainTest(unittest.TestCase):
 
 
         trainer, model = self.get_detection_trainer(self.folder_names[1])
         trainer, model = self.get_detection_trainer(self.folder_names[1])
 
 
-        test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
+        test_metrics = [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=5)]
 
 
         assert isinstance(trainer.test(model=model, silent_mode=True,
         assert isinstance(trainer.test(model=model, silent_mode=True,
                                        test_metrics_list=test_metrics, test_loader=detection_test_dataloader(image_size=320)), tuple)
                                        test_metrics_list=test_metrics, test_loader=detection_test_dataloader(image_size=320)), tuple)
Discard
@@ -11,7 +11,7 @@ import shutil
 
 
 class SgTrainerLoggingTest(unittest.TestCase):
 class SgTrainerLoggingTest(unittest.TestCase):
     def test_train_logging(self):
     def test_train_logging(self):
-        trainer = Trainer("test_train_with_full_log", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_full_log")
 
 
         net = ResNet18(num_classes=5, arch_params={})
         net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
Discard
@@ -19,7 +19,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
     """
     """
 
 
     def test_train_with_external_criterion(self):
     def test_train_with_external_criterion(self):
-        trainer = Trainer("external_criterion_test", model_checkpoints_location='local')
+        trainer = Trainer("external_criterion_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         model = models.get("resnet18", arch_params={"num_classes": 5})
         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -33,7 +33,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
 
 
     def test_train_with_external_optimizer(self):
     def test_train_with_external_optimizer(self):
-        trainer = Trainer("external_optimizer_test", model_checkpoints_location='local')
+        trainer = Trainer("external_optimizer_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         model = models.get("resnet18", arch_params={"num_classes": 5})
         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -47,7 +47,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
 
 
     def test_train_with_external_scheduler(self):
     def test_train_with_external_scheduler(self):
-        trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         lr = 0.3
         lr = 0.3
@@ -66,7 +66,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         self.assertTrue(lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1)
         self.assertTrue(lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1)
 
 
     def test_train_with_external_scheduler_class(self):
     def test_train_with_external_scheduler_class(self):
-        trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         model = models.get("resnet18", arch_params={"num_classes": 5})
         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -81,7 +81,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
 
 
     def test_train_with_reduce_on_plateau(self):
     def test_train_with_reduce_on_plateau(self):
-        trainer = Trainer("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_reduce_on_plateau_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         lr = 0.3
         lr = 0.3
@@ -101,7 +101,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         self.assertTrue(lr_scheduler._last_lr[0] == lr * 0.1)
         self.assertTrue(lr_scheduler._last_lr[0] == lr * 0.1)
 
 
     def test_train_with_external_metric(self):
     def test_train_with_external_metric(self):
-        trainer = Trainer("external_metric_test", model_checkpoints_location='local')
+        trainer = Trainer("external_metric_test")
         dataloader = classification_test_dataloader(batch_size=10)
         dataloader = classification_test_dataloader(batch_size=10)
 
 
         model = models.get("resnet18", arch_params={"num_classes": 5})
         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -114,7 +114,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
 
 
     def test_train_with_external_dataloaders(self):
     def test_train_with_external_dataloaders(self):
-        trainer = Trainer("external_data_loader_test", model_checkpoints_location='local')
+        trainer = Trainer("external_data_loader_test")
 
 
         batch_size = 5
         batch_size = 5
         trainset = torch.utils.data.TensorDataset(torch.Tensor(np.random.random((10, 3, 32, 32))),
         trainset = torch.utils.data.TensorDataset(torch.Tensor(np.random.random((10, 3, 32, 32))),
Discard
@@ -12,7 +12,7 @@ class TrainWithPreciseBNTest(unittest.TestCase):
     """
     """
 
 
     def test_train_with_precise_bn_explicit_size(self):
     def test_train_with_precise_bn_explicit_size(self):
-        trainer = Trainer("test_train_with_precise_bn_explicit_size", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_precise_bn_explicit_size")
         net = ResNet18(num_classes=5, arch_params={})
         net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
@@ -26,7 +26,7 @@ class TrainWithPreciseBNTest(unittest.TestCase):
                       valid_loader=classification_test_dataloader(batch_size=10))
                       valid_loader=classification_test_dataloader(batch_size=10))
 
 
     def test_train_with_precise_bn_implicit_size(self):
     def test_train_with_precise_bn_implicit_size(self):
-        trainer = Trainer("test_train_with_precise_bn_implicit_size", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_precise_bn_implicit_size")
 
 
         net = ResNet18(num_classes=5, arch_params={})
         net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
Discard
@@ -29,7 +29,7 @@ class UpdateParamGroupsTest(unittest.TestCase):
     def test_lr_scheduling_with_update_param_groups(self):
     def test_lr_scheduling_with_update_param_groups(self):
         # Define Model
         # Define Model
         net = TestNet()
         net = TestNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
Discard
Discard