@@ -2,14 +2,14 @@ import inspect
 import os
 import sys
 from copy import deepcopy
-from typing import Union, Tuple, Mapping, List, Any
+from typing import Union, Tuple, Mapping
 import hydra
 import numpy as np
 import torch
 from omegaconf import DictConfig
 from torch import nn
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
 from torch.cuda.amp import GradScaler, autocast
 from torchmetrics import MetricCollection
 from tqdm import tqdm
@@ -33,14 +33,12 @@ from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.utils import sg_trainer_utils
 from super_gradients.training.utils.quantization_utils import QATCallback
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args
-from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, \
-    IllegalDataloaderInitialization, GPUModeNotSetupError
+from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, GPUModeNotSetupError
 from super_gradients.training.losses import LOSSES
 from super_gradients.training.metrics.metric_utils import get_metrics_titles, get_metrics_results_tuple, \
     get_logging_values, \
     get_metrics_dict, get_train_loop_description_dict
 from super_gradients.training.params import TrainingParams
-from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.utils.distributed_training_utils import MultiGPUModeAutocastWrapper, \
     reduce_results_tuple_for_ddp, compute_precise_bn_stats, setup_gpu_mode, require_gpu_setup
 from super_gradients.training.utils.ema import ModelEMA
@@ -76,30 +74,19 @@ class Trainer:
         returns the test loss, accuracy and runtime
     """

-    def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
-                 model_checkpoints_location: str = 'local',
-                 overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth',
-                 post_prediction_callback: DetectionPostPredictionCallback = None, ckpt_root_dir: str = None,
-                 train_loader: DataLoader = None, valid_loader: DataLoader = None, test_loader: DataLoader = None,
-                 classes: List[Any] = None):
+    def __init__(self, experiment_name: str, device: str = None,
+                 multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF,
+                 ckpt_root_dir: str = None):
         """

         :param experiment_name: Used for logging and loading purposes
         :param device: If equal to 'cpu' runs on the CPU otherwise on GPU
         :param multi_gpu: If True, runs on all available devices
-        :param model_checkpoints_location: If set to 's3' saves the Checkpoints in AWS S3
                                            otherwise saves the Checkpoints Locally
-        :param overwrite_local_checkpoint: If set to False keeps the current local checkpoint when importing
                                            checkpoint from cloud service, otherwise overwrites the local checkpoints file
-        :param ckpt_name: The Checkpoint to Load
         :param ckpt_root_dir: Local root directory path where all experiment logging directories will
                               reside. When none is give, it is assumed that
                               pkg_resources.resource_filename('checkpoints', "") exists and will be used.
-        :param train_loader: Training set Dataloader instead of using DatasetInterface, must pass "valid_loader"
-                             and "classes" along with it
-        :param valid_loader: Validation set Dataloader
-        :param test_loader: Test set Dataloader
-        :param classes: List of class labels

         """
         # SET THE EMPTY PROPERTIES
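Note: after this hunk the constructor carries only experiment bookkeeping; the dataloaders move to `train()` and the checkpoint name moves into `training_params` (see the hunks below). A minimal usage sketch of the reduced signature follows; the import path and all argument values are illustrative assumptions, not part of this diff:

    # Assumed import path for this revision; adjust to the actual module layout.
    from super_gradients import Trainer

    # Placeholder experiment name and checkpoint directory.
    trainer = Trainer(experiment_name="my_experiment",
                      device="cpu",
                      ckpt_root_dir="/path/to/checkpoints")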
@@ -109,7 +96,6 @@ class Trainer:
         self.ema_model = None
         self.sg_logger = None
         self.update_param_groups = None
-        self.post_prediction_callback = None
         self.criterion = None
         self.training_params = None
         self.scaler = None
@@ -144,10 +130,7 @@ class Trainer:

         # SETTING THE PROPERTIES FROM THE CONSTRUCTOR
         self.experiment_name = experiment_name
-        self.ckpt_name = ckpt_name
-        self.overwrite_local_checkpoint = overwrite_local_checkpoint
-        self.model_checkpoints_location = model_checkpoints_location
-        self._set_dataset_properties(classes, test_loader, train_loader, valid_loader)
+        self.ckpt_name = None

         # CREATING THE LOGGING DIR BASED ON THE INPUT PARAMS TO PREVENT OVERWRITE OF LOCAL VERSION
         if ckpt_root_dir:
@@ -161,7 +144,6 @@ class Trainer:
         # INITIALIZE THE DEVICE FOR THE MODEL
         self._initialize_device(requested_device=device, requested_multi_gpu=multi_gpu)

-        self.post_prediction_callback = post_prediction_callback
         # SET THE DEFAULTS
         # TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK

@@ -221,25 +203,18 @@ class Trainer:
                       valid_loader=val_dataloader,
                       training_params=cfg.training_hyperparams)

-    def _set_dataset_properties(self, classes, test_loader, train_loader, valid_loader):
-        if any([train_loader, valid_loader, classes]) and not all([train_loader, valid_loader, classes]):
-            raise IllegalDataloaderInitialization()
-
-        dataset_params = {"batch_size": train_loader.batch_size if train_loader else None,
-                          "val_batch_size": valid_loader.batch_size if valid_loader else None,
-                          "test_batch_size": test_loader.batch_size if test_loader else None,
-                          "dataset_dir": None,
-                          "s3_link": None}
-
-        if train_loader and self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
-            if not all([isinstance(train_loader.sampler, DistributedSampler),
-                        isinstance(valid_loader.sampler, DistributedSampler),
-                        test_loader is None or isinstance(test_loader.sampler, DistributedSampler)]):
-                logger.warning(
-                    "DDP training was selected but the dataloader samplers are not of type DistributedSamplers")
-
-        self.dataset_params, self.train_loader, self.valid_loader, self.test_loader, self.classes = \
-            HpmStruct(**dataset_params), train_loader, valid_loader, test_loader, classes
+    def _set_dataset_params(self):
+        self.dataset_params = {
+            "train_dataset_params": self.train_loader.dataset.dataset_params if hasattr(self.train_loader.dataset,
+                                                                                        "dataset_params") else None,
+            "train_dataloader_params": self.train_loader.dataloader_params if hasattr(self.train_loader,
+                                                                                      "dataloader_params") else None,
+            "valid_dataset_params": self.valid_loader.dataset.dataset_params if hasattr(self.valid_loader.dataset,
+                                                                                        "dataset_params") else None,
+            "valid_dataloader_params": self.valid_loader.dataloader_params if hasattr(self.valid_loader,
+                                                                                      "dataloader_params") else None
+        }
+        self.dataset_params = HpmStruct(**self.dataset_params)

     def _set_ckpt_loading_attributes(self):
         """
@@ -408,8 +383,7 @@ class Trainer:
                                             source_ckpt_folder_name=self.source_ckpt_folder_name,
                                             metric_to_watch=self.metric_to_watch,
                                             metric_idx=self.metric_idx_in_results_tuple,
-                                            load_checkpoint=self.load_checkpoint,
-                                            model_checkpoints_location=self.model_checkpoints_location)
+                                            load_checkpoint=self.load_checkpoint)

     def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context: PhaseContext, *args, **kwargs):
         """
@@ -520,6 +494,7 @@ class Trainer:
         self.load_ema_as_net = False
         self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
         self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
+        self.ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", 'ckpt_latest.pth')
         self._load_checkpoint_to_model()

     def _init_arch_params(self):
@@ -546,6 +521,21 @@ class Trainer:
         :param train_loader: Dataloader for train set.
         :param valid_loader: Dataloader for validation.
         :param training_params:
+
+            - `resume` : bool (default=False)
+
+                Whether to continue training from ckpt with the same experiment name
+                (i.e resume from CKPT_ROOT_DIR/EXPERIMENT_NAME/CKPT_NAME)
+
+            - `ckpt_name` : str (default=ckpt_latest.pth)
+
+                The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and
+                resume_path=None
+
+            - `resume_path`: str (default=None)
+
+                Explicit checkpoint path (.pth file) to use to resume training.
+
             - `max_epochs` : int

                 Number of epochs to run training.
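Note: the three resume-related keys documented above replace the old constructor argument `ckpt_name` and are read from `training_params` at the start of training (see the @@ -520 hunk). A sketch of resuming from a specific checkpoint follows; the values are illustrative, `trainer` / `train_loader` / `valid_loader` are assumed to exist already, and any other required arguments and hyper-parameters are omitted for brevity:

    # Illustrative values only; the remaining training hyper-parameters
    # (max_epochs, loss, optimizer, metrics, ...) would still be required.
    training_params = {
        "resume": True,                 # resume from CKPT_ROOT_DIR/EXPERIMENT_NAME/CKPT_NAME
        "ckpt_name": "ckpt_best.pth",   # overrides the 'ckpt_latest.pth' default
        # "resume_path": "/explicit/path/to/ckpt.pth",  # takes precedence when given
    }

    trainer.train(train_loader=train_loader,
                  valid_loader=valid_loader,
                  training_params=training_params)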
@@ -859,6 +849,7 @@ class Trainer:

         self.train_loader = train_loader or self.train_loader
         self.valid_loader = valid_loader or self.valid_loader
+        self._set_dataset_params()

         self.training_params = TrainingParams()
         self.training_params.override(**training_params)
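Note: with this call, `train()` snapshots the dataset/dataloader parameters by probing the attached loaders (see `_set_dataset_params` above); loaders that do not carry those attributes simply contribute None. A small self-contained sketch of that probing pattern, using made-up data and the equivalent getattr form:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # A vanilla DataLoader exposes neither `dataset_params` nor `dataloader_params`,
    # so both lookups fall back to None, mirroring the hasattr checks above.
    loader = DataLoader(TensorDataset(torch.zeros(8, 3)), batch_size=4)
    snapshot = {
        "train_dataset_params": getattr(loader.dataset, "dataset_params", None),
        "train_dataloader_params": getattr(loader, "dataloader_params", None),
    }
    print(snapshot)  # {'train_dataset_params': None, 'train_dataloader_params': None}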
@@ -1120,10 +1111,6 @@ class Trainer:
             self.phase_callback_handler(Phase.POST_TRAINING, context)

             if not self.ddp_silent_mode:
-                if self.model_checkpoints_location != 'local':
-                    logger.info('[CLEANUP] - Saving Checkpoint files')
-                    self.sg_logger.upload()
-
                 self.sg_logger.close()

     def _reset_best_metric(self):
@@ -1367,10 +1354,7 @@ class Trainer:
         ckpt_local_path = get_ckpt_local_path(source_ckpt_folder_name=self.source_ckpt_folder_name,
                                               experiment_name=self.experiment_name,
                                               ckpt_name=self.ckpt_name,
-                                              model_checkpoints_location=self.model_checkpoints_location,
-                                              external_checkpoint_path=self.external_checkpoint_path,
-                                              overwrite_local_checkpoint=self.overwrite_local_checkpoint,
-                                              load_weights_only=self.load_weights_only)
+                                              external_checkpoint_path=self.external_checkpoint_path)

         # LOAD CHECKPOINT TO MODEL
         self.checkpoint = load_checkpoint_to_model(ckpt_local_path=ckpt_local_path,
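Note: with the S3-related arguments gone, checkpoint resolution reduces to the experiment-local path unless an explicit resume path is given. The helper below is a hypothetical illustration of the behaviour described in the training_params docstring, not the actual get_ckpt_local_path implementation:

    import os

    # An explicit resume_path wins; otherwise the experiment-local checkpoint is used.
    def resolve_ckpt_path(ckpt_root_dir, experiment_name, ckpt_name, external_checkpoint_path=None):
        return external_checkpoint_path or os.path.join(ckpt_root_dir, experiment_name, ckpt_name)

    print(resolve_ckpt_path("/checkpoints", "my_experiment", "ckpt_latest.pth"))
    # /checkpoints/my_experiment/ckpt_latest.pth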
@@ -1390,7 +1374,7 @@ class Trainer:
         self.best_metric = self.checkpoint['acc'] if 'acc' in self.checkpoint.keys() else -1
         self.start_epoch = self.checkpoint['epoch'] if 'epoch' in self.checkpoint.keys() else 0

-    def _prep_for_test(self, test_loader: torch.utils.data.DataLoader = None, loss=None, post_prediction_callback=None,
+    def _prep_for_test(self, test_loader: torch.utils.data.DataLoader = None, loss=None,
                        test_metrics_list=None,
                        loss_logging_items_names=None, test_phase_callbacks=None):
         """Run commands that are common to all models"""
@@ -1400,7 +1384,6 @@ class Trainer:
         # IF SPECIFIED IN THE FUNCTION CALL - OVERRIDE THE self ARGUMENTS
         self.test_loader = test_loader or self.test_loader
         self.criterion = loss or self.criterion
-        self.post_prediction_callback = post_prediction_callback or self.post_prediction_callback
         self.loss_logging_items_names = loss_logging_items_names or self.loss_logging_items_names
         self.phase_callbacks = test_phase_callbacks or self.phase_callbacks

@@ -1445,7 +1428,7 @@ class Trainer:

         # OVERRIDE SOME PARAMETERS TO MAKE SURE THEY MATCH THE TRAINING PARAMETERS
         general_sg_logger_params = {'experiment_name': self.experiment_name,
-                                    'storage_location': self.model_checkpoints_location,
+                                    'storage_location': 'local',
                                     'resumed': self.load_checkpoint,
                                     'training_params': self.training_params,
                                     'checkpoints_dir_path': self.checkpoints_dir_path}