@@ -1,15 +1,16 @@
 import hydra
 import torch.nn
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from torch.utils.data import DataLoader
 
+from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.common import MultiGPUMode
 from super_gradients.training.dataloaders import dataloaders
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.sg_trainer import Trainer
-from typing import Union
+from typing import Union, Dict
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import utils as core_utils, models
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
@@ -25,7 +26,6 @@ from super_gradients.training.exceptions.kd_trainer_exceptions import (
 )
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
 from super_gradients.training.utils.ema import KDModelEMA
-from super_gradients.training.utils.sg_trainer_utils import parse_args
 
 logger = get_logger(__name__)
 
@@ -47,11 +47,16 @@ class KDTrainer(Trainer):
         @return: output of kd_trainer.train(...) (i.e results tuple)
         """
         # INSTANTIATE ALL OBJECTS IN CFG
-        cfg = hydra.utils.instantiate(cfg)
+        setup_device(
+            device=core_utils.get_param(cfg, "device"),
+            multi_gpu=core_utils.get_param(cfg, "multi_gpu"),
+            num_gpus=core_utils.get_param(cfg, "num_gpus"),
+        )
 
-        kwargs = parse_args(cfg, cls.__init__)
+        # INSTANTIATE ALL OBJECTS IN CFG
+        cfg = hydra.utils.instantiate(cfg)
 
-        trainer = KDTrainer(**kwargs)
+        trainer = KDTrainer(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)
 
         # INSTANTIATE DATA LOADERS
         train_dataloader = dataloaders.get(
@@ -80,6 +85,8 @@ class KDTrainer(Trainer):
             load_backbone=cfg.teacher_checkpoint_params.load_backbone,
         )
 
+        recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
+
         # TRAIN
         trainer.train(
             training_params=cfg.training_hyperparams,
@@ -90,6 +97,7 @@ class KDTrainer(Trainer):
             run_teacher_on_eval=cfg.run_teacher_on_eval,
             train_loader=train_dataloader,
             valid_loader=val_dataloader,
+            additional_configs_to_log=recipe_logged_cfg,
         )
 
     def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
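The two hunks above build a resolved snapshot of the recipe config and forward it to the experiment logger. As a sanity check, the transformation can be reproduced in isolation; in this sketch the cfg fields are illustrative assumptions, not values taken from this diff:

from omegaconf import OmegaConf

# Hypothetical recipe fragment; in train_from_config the cfg arrives via hydra.
cfg = OmegaConf.create(
    {
        "experiment_name": "kd_example",
        "epochs": 2,
        "training_hyperparams": {"max_epochs": "${epochs}"},
    }
)

# Same call as in the hunk above: resolve interpolations ("${epochs}" -> 2)
# and convert the OmegaConf tree into plain python containers.
recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
print(recipe_logged_cfg)
# {'recipe_config': {'experiment_name': 'kd_example', 'epochs': 2,
#                    'training_hyperparams': {'max_epochs': 2}}}

Each top-level key of the dict passed as additional_configs_to_log becomes a config title in the sg_logger output, per the docstring added further down.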
@@ -275,20 +283,22 @@ class KDTrainer(Trainer):
     def train(
         self,
         model: KDModule = None,
-        training_params: dict = dict(),
+        training_params: Dict = None,
         student: SgModule = None,
         teacher: torch.nn.Module = None,
         kd_architecture: Union[KDModule.__class__, str] = "kd_module",
-        kd_arch_params: dict = dict(),
+        kd_arch_params: Dict = None,
         run_teacher_on_eval=False,
         train_loader: DataLoader = None,
         valid_loader: DataLoader = None,
+        additional_configs_to_log: Dict = None,
         *args,
         **kwargs,
     ):
         """
         Trains the student network (wrapped in KDModule network).
 
+
         :param model: KDModule, network to train. When none is given will initialize KDModule according to kd_architecture,
             student and teacher (default=None)
         :param training_params: dict, Same as in Trainer.train()
@@ -299,12 +309,21 @@ class KDTrainer(Trainer):
         :param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)
         :param train_loader: Dataloader for train set.
         :param valid_loader: Dataloader for validation.
+        :param additional_configs_to_log: Dict, dictionary containing configs that will be added to the training's
+            sg_logger. Format should be {"Config_title_1": {...}, "Config_title_2": {...}} (optional, default=None)
         """
         kd_net = self.net or model
+        kd_arch_params = kd_arch_params or dict()
         if kd_net is None:
             if student is None or teacher is None:
                 raise ValueError("Must pass student and teacher models or net (KDModule).")
             kd_net = self._instantiate_kd_net(
                 arch_params=HpmStruct(**kd_arch_params), architecture=kd_architecture, run_teacher_on_eval=run_teacher_on_eval, student=student, teacher=teacher
             )
-        super(KDTrainer, self).train(model=kd_net, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
+        super(KDTrainer, self).train(
+            model=kd_net,
+            training_params=training_params,
+            train_loader=train_loader,
+            valid_loader=valid_loader,
+            additional_configs_to_log=additional_configs_to_log,
+        )
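
A note on the signature changes above: replacing the mutable defaults training_params: dict = dict() and kd_arch_params: dict = dict() with Dict = None, together with the kd_arch_params = kd_arch_params or dict() line in the body, is the standard fix for Python's shared-mutable-default pitfall, where the default object is created once at function definition and reused across calls. A minimal, self-contained illustration (function names are hypothetical, not from this diff):

def bad(params: dict = dict()):
    # the default dict is created once and shared by every call
    params.setdefault("calls", 0)
    params["calls"] += 1
    return params

def good(params: dict = None):
    params = params or dict()  # fresh dict per call, as in train() above
    params.setdefault("calls", 0)
    params["calls"] += 1
    return params

print(bad())   # {'calls': 1}
print(bad())   # {'calls': 2}  <- state leaked across calls
print(good())  # {'calls': 1}
print(good())  # {'calls': 1}  <- independent calls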