
#396 Trainer constructor cleanup

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
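This PR slims the Trainer constructor down to experiment_name, device, multi_gpu and ckpt_root_dir: the model_checkpoints_location / overwrite_local_checkpoint cloud-checkpoint plumbing, the post_prediction_callback argument and the dataloader/classes arguments are removed, the checkpoint filename becomes a training hyperparameter (ckpt_name, next to resume and resume_path), and dataloaders are passed to train() instead. A rough usage sketch after the change, based on the hunks below; the import paths, the ckpt_root_dir value and the truncated training_params dict are illustrative assumptions, not taken from this diff:

import super_gradients
from super_gradients.training import dataloaders, MultiGPUMode
from super_gradients.training.sg_trainer import Trainer

super_gradients.init_trainer()

# The constructor now only carries experiment metadata; no checkpoint-storage arguments.
trainer = Trainer("resnet18_qat_example",
                  multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL,
                  ckpt_root_dir="/data/checkpoints")  # assumed local path

# Dataloaders go to train(), and the checkpoint filename is now a training hyperparameter.
train_loader = dataloaders.imagenet_train()
valid_loader = dataloaders.imagenet_val()
trainer.train(train_loader=train_loader,
              valid_loader=valid_loader,
              training_params={"resume": False,
                               "ckpt_name": "ckpt_latest.pth"})  # remaining hyperparameters omitted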
@@ -25,7 +25,6 @@ from super_gradients.training.utils.quantization_utils import PostQATConversionC
 super_gradients.init_trainer()
 
 trainer = Trainer("resnet18_qat_example",
-                  model_checkpoints_location='local',
                   multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
 
 train_loader = dataloaders.imagenet_train()
@@ -22,7 +22,7 @@ training_hyperparams:
   resume: ${resume}
 
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 architecture: resnet18_cifar
@@ -78,7 +78,7 @@ checkpoint_params:
 
 experiment_name: ${architecture}_cityscapes
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 multi_gpu: DDP
@@ -44,7 +44,7 @@ arch_params:
   strict_load: no_key_matching
 
 load_checkpoint: False
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 resume: False
@@ -24,7 +24,7 @@ checkpoint_params:
 architecture: stdc1_seg
 experiment_name: ${architecture}_cityscapes
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 multi_gpu: DDP
@@ -34,7 +34,7 @@ val_dataloader: coco2017_val
 architecture: ssd_lite_mobilenet_v2
 
 data_loader_num_workers: 8
-model_checkpoints_location: local
+
 experiment_suffix: res${dataset_params.train_image_size}
 experiment_name: ${architecture}_coco_${experiment_suffix}
 
@@ -40,7 +40,7 @@ defaults:
 train_dataloader: coco2017_train
 val_dataloader: coco2017_val
 
-model_checkpoints_location: local
+
 
 load_checkpoint: False
 resume: False
@@ -9,7 +9,7 @@
 #   0. Make sure that the data is stored in dataset_params.dataset_dir or add "dataset_params.data_dir=<PATH-TO-DATASET>" at the end of the command below (feel free to check ReadMe)
 #   1. Move to the project root (where you will find the ReadMe and src folder)
 #   2. Run the command:
-#       python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco_segmentation_shelfnet_lw --model_checkpoints_location=<checkpoint-location>
+#       python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco_segmentation_shelfnet_lw
 
 
 # /!\ THIS RECIPE IS NOT SUPPORTED AT THE MOMENT /!\
@@ -27,7 +27,7 @@ training_hyperparams:
 
 experiment_name: efficientnet_b0_imagenet
 
-model_checkpoints_location: local
+
 ckpt_root_dir:
 
 multi_gpu: DDP
@@ -25,7 +25,7 @@ arch_params:
 
 data_loader_num_workers: 8
 
-model_checkpoints_location: local
+
 resume: False
 training_hyperparams:
   resume: ${resume}
@@ -8,7 +8,7 @@ defaults:
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
-model_checkpoints_location: local
+
 resume: False
 training_hyperparams:
   resume: ${resume}
@@ -36,7 +36,7 @@ arch_params:
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
-model_checkpoints_location: local
+
 load_checkpoint: False
 resume: False
 training_hyperparams:
@@ -25,7 +25,7 @@ train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
 
-model_checkpoints_location: local
+
 resume: False
 training_hyperparams:
   resume: ${resume}
@@ -24,7 +24,7 @@ arch_params:
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
-model_checkpoints_location: local
+
 resume: False
 training_hyperparams:
   resume: ${resume}
@@ -66,7 +66,7 @@ student_checkpoint_params:
   pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenent").
 
 
-model_checkpoints_location: local
+
 
 
 run_teacher_on_eval: True
@@ -21,7 +21,7 @@ defaults:
 train_dataloader: imagenet_train
 val_dataloader: imagenet_val
 
-model_checkpoints_location: local
+
 
 resume: False
 training_hyperparams:
@@ -1,4 +1,6 @@
 resume: False # whether to continue training from ckpt with the same experiment name.
+resume_path: # Explicit checkpoint path (.pth file) to use to resume training.
+ckpt_name: ckpt_latest.pth  # The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and resume_path=None
 lr_mode: # Learning rate scheduling policy, one of ['step','poly','cosine','function']
 lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
 lr_warmup_epochs: 0 # number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
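Together with the existing resume flag, the two new keys give three resume-related knobs; the same options can be passed from Python through training_params when not launching from a recipe. A minimal sketch of their semantics as documented in the YAML above (paths and the remaining hyperparameters are placeholders):

# Resume behaviour after this change, mirroring the comments in the defaults above:
training_params = {
    "resume": True,                   # continue from CKPT_ROOT_DIR/EXPERIMENT_NAME/<ckpt_name>
    "ckpt_name": "ckpt_latest.pth",   # which .pth file to pick up when resume=True and resume_path is None
    "resume_path": None,              # or an explicit /path/to/checkpoint.pth, which takes precedence
    # ... the rest of the training hyperparameters (epochs, lr, loss, metrics) go here as usual
}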
@@ -9,7 +9,7 @@ from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.sg_trainer import Trainer
-from typing import Union, List, Any
+from typing import Union
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
@@ -20,7 +20,6 @@ from super_gradients.training.exceptions.kd_trainer_exceptions import Architectu
     UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
     TeacherKnowledgeException, UndefinedNumClassesException
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
-from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.utils.ema import KDModelEMA
 from super_gradients.training.utils.sg_trainer_utils import parse_args
 
@@ -29,15 +28,8 @@ logger = get_logger(__name__)
 
 class KDTrainer(Trainer):
     def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 model_checkpoints_location: str = 'local', overwrite_local_checkpoint: bool = True,
-                 ckpt_name: str = 'ckpt_latest.pth',
-                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
-                 train_loader: DataLoader = None,
-                 valid_loader: DataLoader = None, test_loader: DataLoader = None, classes: List[Any] = None):
-
-        super().__init__(experiment_name, device, multi_gpu, model_checkpoints_location, overwrite_local_checkpoint,
-                         ckpt_name, post_prediction_callback,
-                         ckpt_root_dir, train_loader, valid_loader, test_loader, classes)
+                 ckpt_root_dir: str = None):
+        super().__init__(experiment_name=experiment_name, device=device, multi_gpu=multi_gpu, ckpt_root_dir=ckpt_root_dir)
         self.student_architecture = None
         self.teacher_architecture = None
         self.student_arch_params = None
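KDTrainer now exposes the same four constructor arguments and simply forwards them to Trainer; everything distillation-specific starts out as None and is configured later. A minimal sketch, assuming this import path for KDTrainer (the module location is not shown in this diff) and placeholder names:

# Import path assumed for illustration; adjust to wherever KDTrainer lives in your checkout.
from super_gradients.training.kd_trainer import KDTrainer
from super_gradients.training import MultiGPUMode

kd_trainer = KDTrainer(experiment_name="resnet50_kd_imagenet",  # placeholder experiment name
                       device=None,
                       multi_gpu=MultiGPUMode.OFF,
                       ckpt_root_dir=None)
# Student/teacher architectures and their arch_params are initialized to None here and set up later.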
@@ -63,6 +63,7 @@ DEFAULT_TRAINING_PARAMS = {"lr_warmup_epochs": 0,
                            },
                            "resume": False,
                            "resume_path": None,
+                           "ckpt_name": 'ckpt_latest.pth',
                            "resume_strict_load": False
                            }
 
@@ -2,14 +2,14 @@ import inspect
 import os
 import sys
 from copy import deepcopy
-from typing import Union, Tuple, Mapping, List, Any
+from typing import Union, Tuple, Mapping
 
 import hydra
 import numpy as np
 import torch
 from omegaconf import DictConfig
 from torch import nn
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
 from torch.cuda.amp import GradScaler, autocast
 from torchmetrics import MetricCollection
 from tqdm import tqdm
@@ -33,14 +33,12 @@ from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.utils import sg_trainer_utils
 from super_gradients.training.utils.quantization_utils import QATCallback
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args
-from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, \
-    IllegalDataloaderInitialization, GPUModeNotSetupError
+from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, GPUModeNotSetupError
 from super_gradients.training.losses import LOSSES
 from super_gradients.training.metrics.metric_utils import get_metrics_titles, get_metrics_results_tuple, \
     get_logging_values, \
     get_metrics_dict, get_train_loop_description_dict
 from super_gradients.training.params import TrainingParams
-from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.utils.distributed_training_utils import MultiGPUModeAutocastWrapper, \
     reduce_results_tuple_for_ddp, compute_precise_bn_stats, setup_gpu_mode, require_gpu_setup
 from super_gradients.training.utils.ema import ModelEMA
@@ -76,30 +74,19 @@ class Trainer:
         returns the test loss, accuracy and runtime
     """
 
-    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 model_checkpoints_location: str = 'local',
-                 overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth',
-                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
-                 train_loader: DataLoader = None, valid_loader: DataLoader = None, test_loader: DataLoader = None,
-                 classes: List[Any] = None):
+    def __init__(self, experiment_name: str, device: str = None,
+                 multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
+                 ckpt_root_dir: str = None):
         """
 
         :param experiment_name:                      Used for logging and loading purposes
         :param device:                          If equal to 'cpu' runs on the CPU otherwise on GPU
         :param multi_gpu:                       If True, runs on all available devices
-        :param model_checkpoints_location:      If set to 's3' saves the Checkpoints in AWS S3
                                                 otherwise saves the Checkpoints Locally
-        :param overwrite_local_checkpoint:      If set to False keeps the current local checkpoint when importing
                                                 checkpoint from cloud service, otherwise overwrites the local checkpoints file
-        :param ckpt_name:                       The Checkpoint to Load
         :param ckpt_root_dir:                   Local root directory path where all experiment logging directories will
                                                 reside. When none is give, it is assumed that
                                                 pkg_resources.resource_filename('checkpoints', "") exists and will be used.
-        :param train_loader:                    Training set Dataloader instead of using DatasetInterface, must pass "valid_loader"
-                                                and "classes" along with it
-        :param valid_loader:                    Validation set Dataloader
-        :param test_loader:                     Test set Dataloader
-        :param classes:                         List of class labels
 
         """
         # SET THE EMPTY PROPERTIES
@@ -109,7 +96,6 @@ class Trainer:
         self.ema_model = None
         self.sg_logger = None
         self.update_param_groups = None
-        self.post_prediction_callback = None
         self.criterion = None
         self.training_params = None
         self.scaler = None
@@ -144,10 +130,7 @@ class Trainer:
 
 
         # SETTING THE PROPERTIES FROM THE CONSTRUCTOR
         self.experiment_name = experiment_name
-        self.ckpt_name = ckpt_name
-        self.overwrite_local_checkpoint = overwrite_local_checkpoint
-        self.model_checkpoints_location = model_checkpoints_location
-        self._set_dataset_properties(classes, test_loader, train_loader, valid_loader)
+        self.ckpt_name = None
 
         # CREATING THE LOGGING DIR BASED ON THE INPUT PARAMS TO PREVENT OVERWRITE OF LOCAL VERSION
         if ckpt_root_dir:
@@ -161,7 +144,6 @@ class Trainer:
         # INITIALIZE THE DEVICE FOR THE MODEL
         self._initialize_device(requested_device=device, requested_multi_gpu=multi_gpu)
 
-        self.post_prediction_callback = post_prediction_callback
         # SET THE DEFAULTS
         # TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK
 
@@ -221,25 +203,18 @@ class Trainer:
                       valid_loader=val_dataloader,
                       training_params=cfg.training_hyperparams)
 
-    def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
-        if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
-            raise IllegalDataloaderInitialization()
-
-        dataset_params = {"batch_size": train_loader.batch_size if train_loader else None,
-                          "val_batch_size": valid_loader.batch_size if valid_loader else None,
-                          "test_batch_size": test_loader.batch_size if test_loader else None,
-                          "dataset_dir": None,
-                          "s3_link": None}
-
-        if train_loader and self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
-            if not all([isinstance(train_loader.sampler, DistributedSampler),
-                        isinstance(valid_loader.sampler, DistributedSampler),
-                        test_loader is None or isinstance(test_loader.sampler, DistributedSampler)]):
-                logger.warning(
-                    "DDP training was selected but the dataloader samplers are not of type DistributedSamplers")
-
-        self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
-            HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
+    def _set_dataset_params(self):
+        self.dataset_params = {
+            "train_dataset_params": self.train_loader.dataset.dataset_params if hasattr(self.train_loader.dataset,
+                                                                                        "dataset_params") else None,
+            "train_dataloader_params": self.train_loader.dataloader_params if hasattr(self.train_loader,
+                                                                                      "dataloader_params") else None,
+            "valid_dataset_params": self.valid_loader.dataset.dataset_params if hasattr(self.valid_loader.dataset,
+                                                                                        "dataset_params") else None,
+            "valid_dataloader_params": self.valid_loader.dataloader_params if hasattr(self.valid_loader,
+                                                                                      "dataloader_params") else None
+        }
+        self.dataset_params = HpmStruct(**self.dataset_params)
 
     def _set_ckpt_loading_attributes(self):
         """
@@ -408,8 +383,7 @@ class Trainer:
                                                                source_ckpt_folder_name=self.source_ckpt_folder_name,
                                                                metric_to_watch=self.metric_to_watch,
                                                                metric_idx=self.metric_idx_in_results_tuple,
-                                                               load_checkpoint=self.load_checkpoint,
-                                                               model_checkpoints_location=self.model_checkpoints_location)
+                                                               load_checkpoint=self.load_checkpoint)
 
     def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context: PhaseContext, *args, **kwargs):
         """
@@ -520,6 +494,7 @@ class Trainer:
         self.load_ema_as_net = False
         self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
         self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
+        self.ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", 'ckpt_latest.pth')
         self._load_checkpoint_to_model()
 
     def _init_arch_params(self):
@@ -546,6 +521,21 @@ class Trainer:
             :param train_loader: Dataloader for train set.
             :param valid_loader: Dataloader for validation.
             :param training_params:
+
+                - `resume` : bool (default=False)
+
+                    Whether to continue training from ckpt with the same experiment name
+                     (i.e resume from CKPT_ROOT_DIR/EXPERIMENT_NAME/CKPT_NAME)
+
+                - `ckpt_name` : str (default=ckpt_latest.pth)
+
+                    The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and
+                     resume_path=None
+
+                - `resume_path`: str (default=None)
+
+                    Explicit checkpoint path (.pth file) to use to resume training.
+
                 - `max_epochs` : int
 
                     Number of epochs to run training.
@@ -859,6 +849,7 @@ class Trainer:
 
 
         self.train_loader = train_loader or self.train_loader
         self.valid_loader = valid_loader or self.valid_loader
+        self._set_dataset_params()
 
         self.training_params = TrainingParams()
         self.training_params.override(**training_params)
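train() now calls the new _set_dataset_params() on the loaders it receives. Unlike the removed _set_dataset_properties, it performs no sampler or batch-size validation; it just duck-types the loaders, recording dataset_params / dataloader_params when the objects expose them and falling back to None otherwise. A small self-contained illustration of that fallback with a plain PyTorch loader, which carries neither attribute:

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.zeros(8, 3)), batch_size=4)

# Same hasattr pattern as _set_dataset_params: a vanilla loader yields None for both entries.
dataset_params = loader.dataset.dataset_params if hasattr(loader.dataset, "dataset_params") else None
dataloader_params = loader.dataloader_params if hasattr(loader, "dataloader_params") else None
print(dataset_params, dataloader_params)  # -> None None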
@@ -1120,10 +1111,6 @@ class Trainer:
             self.phase_callback_handler(Phase.POST_TRAINING, context)
 
             if not self.ddp_silent_mode:
-                if self.model_checkpoints_location != 'local':
-                    logger.info('[CLEANUP] - Saving Checkpoint files')
-                    self.sg_logger.upload()
-
                 self.sg_logger.close()
 
     def _reset_best_metric(self):
@@ -1367,10 +1354,7 @@ class Trainer:
             ckpt_local_path = get_ckpt_local_path(source_ckpt_folder_name=self.source_ckpt_folder_name,
                                                   experiment_name=self.experiment_name,
                                                   ckpt_name=self.ckpt_name,
-                                                  model_checkpoints_location=self.model_checkpoints_location,
-                                                  external_checkpoint_path=self.external_checkpoint_path,
-                                                  overwrite_local_checkpoint=self.overwrite_local_checkpoint,
-                                                  load_weights_only=self.load_weights_only)
+                                                  external_checkpoint_path=self.external_checkpoint_path)
 
             # LOAD CHECKPOINT TO MODEL
             self.checkpoint = load_checkpoint_to_model(ckpt_local_path=ckpt_local_path,
@@ -1390,7 +1374,7 @@ class Trainer:
         self.best_metric = self.checkpoint['acc'] if 'acc' in self.checkpoint.keys() else -1
         self.start_epoch = self.checkpoint['epoch'] if 'epoch' in self.checkpoint.keys() else 0
 
-    def _prep_for_test(self, test_loader: torch.utils.data.DataLoader = None, loss=None, post_prediction_callback=None,
+    def _prep_for_test(self, test_loader: torch.utils.data.DataLoader = None, loss=None,
                        test_metrics_list=None,
                        loss_logging_items_names=None, test_phase_callbacks=None):
         """Run commands that are common to all models"""
@@ -1400,7 +1384,6 @@ class Trainer:
         # IF SPECIFIED IN THE FUNCTION CALL - OVERRIDE THE self ARGUMENTS
         self.test_loader = test_loader or self.test_loader
         self.criterion = loss or self.criterion
-        self.post_prediction_callback = post_prediction_callback or self.post_prediction_callback
         self.loss_logging_items_names = loss_logging_items_names or self.loss_logging_items_names
         self.phase_callbacks = test_phase_callbacks or self.phase_callbacks
 
@@ -1445,7 +1428,7 @@ class Trainer:
 
 
         # OVERRIDE SOME PARAMETERS TO MAKE SURE THEY MATCH THE TRAINING PARAMETERS
         general_sg_logger_params = {'experiment_name': self.experiment_name,
-                                    'storage_location': self.model_checkpoints_location,
+                                    'storage_location': 'local',
                                     'resumed': self.load_checkpoint,
                                     'training_params': self.training_params,
                                     'checkpoints_dir_path': self.checkpoints_dir_path}
@@ -10,8 +10,7 @@ except (ModuleNotFoundError, ImportError, NameError):
     from torch.hub import _download_url_to_file as download_url_to_file
 
 
-def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, model_checkpoints_location: str, external_checkpoint_path: str,
-                        overwrite_local_checkpoint: bool, load_weights_only: bool):
+def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, external_checkpoint_path: str):
     """
     Gets the local path to the checkpoint file, which will be:
         - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
@@ -24,30 +23,11 @@ def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt
     @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.
     @param experiment_name: experiment name attr in trainer
     @param ckpt_name: checkpoint filename
-    @param model_checkpoints_location: S3, local ot URL
     @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)
-    @param overwrite_local_checkpoint: whether to overwrite the checkpoint file with the same name when downloading from S3.
-    @param load_weights_only: whether to load the network's state dict only.
     @return:
     """
     source_ckpt_folder_name = source_ckpt_folder_name or experiment_name
-    if model_checkpoints_location == 'local':
-        ckpt_local_path = external_checkpoint_path or pkg_resources.resource_filename('checkpoints', source_ckpt_folder_name + os.path.sep + ckpt_name)
-
-    # COPY THE DATA FROM 'S3'/'URL' INTO A LOCAL DIRECTORY
-    elif model_checkpoints_location.startswith('s3') or model_checkpoints_location == 'url':
-        # COPY REMOTE DATA TO A LOCAL DIRECTORY AND GET THAT DIRECTORYs NAME
-        ckpt_local_path = copy_ckpt_to_local_folder(local_ckpt_destination_dir=experiment_name,
-                                                    ckpt_filename=ckpt_name,
-                                                    remote_ckpt_source_dir=source_ckpt_folder_name,
-                                                    path_src=model_checkpoints_location,
-                                                    overwrite_local_ckpt=overwrite_local_checkpoint,
-                                                    load_weights_only=load_weights_only)
-
-    else:
-        # ERROR IN USER CODE FLOW - THIS WILL EVENTUALLY RAISE AN EXCEPTION
-        raise NotImplementedError(
-            'model_checkpoints_data_source: ' + str(model_checkpoints_location) + 'not supported')
+    ckpt_local_path = external_checkpoint_path or pkg_resources.resource_filename('checkpoints', source_ckpt_folder_name + os.path.sep + ckpt_name)
 
     return ckpt_local_path
 
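With the S3/URL branches gone, get_ckpt_local_path reduces to: use external_checkpoint_path if given, otherwise the packaged checkpoints/<experiment>/<ckpt_name> location. A quick sketch of the explicit-path case (the module path and file names are assumptions for illustration):

# Module path assumed from the package layout visible in this diff.
from super_gradients.training.utils.checkpoint_utils import get_ckpt_local_path

path = get_ckpt_local_path(source_ckpt_folder_name=None,            # falls back to experiment_name
                           experiment_name="resnet18_qat_example",  # placeholder experiment
                           ckpt_name="ckpt_latest.pth",
                           external_checkpoint_path="/tmp/ckpt_best.pth")
print(path)  # -> /tmp/ckpt_best.pth; the explicit path short-circuits the packaged location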
@@ -19,7 +19,6 @@ class ModelWeightAveraging:
                  source_ckpt_folder_name=None, metric_to_watch='acc',
                  metric_idx=1, load_checkpoint=False,
                  number_of_models_to_average=10,
-                 model_checkpoints_location='local'
                  ):
         """
         Init the ModelWeightAveraging
@@ -45,7 +44,6 @@ class ModelWeightAveraging:
                                                                   source_ckpt_folder_name=source_ckpt_folder_name,
                                                                   ckpt_filename="averaging_snapshots.pkl",
                                                                   load_weights_only=False,
-                                                                  model_checkpoints_location=model_checkpoints_location,
                                                                   overwrite_local_ckpt=True)
 
         else: