
#211 SG-136: Apply ema only on student (KD)

Merged
Louis Dupont merged 1 commit into Deci-AI:master from deci-ai:feature/SG-136_use_ema_only_on_kd_student
@@ -14,6 +14,7 @@ from super_gradients.training.exceptions.kd_model_exceptions import Architecture
     UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
     TeacherKnowledgeException, UndefinedNumClassesException
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
+from super_gradients.training.utils.ema import KDModelEMA
 logger = get_logger(__name__)
 
 
@@ -247,3 +248,15 @@ class KDModel(SgModel):
                                    "teacher_arch_params": self.teacher_arch_params
                                    "teacher_arch_params": self.teacher_arch_params
                                    })
         return hyper_param_config
+
+    def instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
+        """Instantiate KD ema model for KDModule.
+
+        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
+        :param decay:           the maximum decay value. as the training process advances, the decay will climb towards
+                                this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
+        :param beta:            the exponent coefficient. The higher the beta, the sooner in the training the decay will
+                                saturate to its final value. beta=15 is ~40% of the training process.
+        :param exp_activation:
+        """
+        return KDModelEMA(self.net, decay, beta, exp_activation)
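For context, the `decay`/`beta` parameters documented above describe an exponentially saturating decay schedule. A minimal sketch of that schedule (the exact expression is an assumption inferred from the docstring and the constant-decay branch shown further down; it is not spelled out in this diff):

```python
import math

def ema_decay_schedule(progress: float, decay: float = 0.9999, beta: float = 15) -> float:
    """Assumed exponential saturation of the EMA decay factor.

    `progress` is the fraction of training completed (0..1). The effective decay starts
    near 0 and climbs towards the maximum `decay`; a higher `beta` makes it saturate
    earlier (with beta=15, it plateaus at roughly 40% of training, as the docstring notes).
    """
    return decay * (1 - math.exp(-progress * beta))

# Each EMA update then uses the effective decay d at the current progress:
#   EMA_t+1 = EMA_t * d + TRAINING_MODEL * (1 - d)
for progress in (0.1, 0.4, 1.0):
    print(f"progress={progress:.1f} -> effective decay={ema_decay_schedule(progress):.4f}")
```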
@@ -823,7 +823,7 @@ class SgModel:
         if self.ema:
             ema_params = self.training_params.ema_params
             logger.info(f'Using EMA with params {ema_params}')
-            self.ema_model = ModelEMA(self.net, **ema_params)
+            self.ema_model = self.instantiate_ema_model(**ema_params)
             self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate
             if self.load_checkpoint:
                 if 'ema_net' in self.checkpoint.keys():
@@ -1792,3 +1792,12 @@ class SgModel:
                 arch_params.num_classes = num_classes_new_head
 
         return net
+
+    def instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> ModelEMA:
+        """Instantiate ema model for standard SgModule.
+        :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
+                      until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
+        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
+                     its final value. beta=15 is ~40% of the training process.
+        """
+        return ModelEMA(self.net, decay, beta, exp_activation)
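Since `SgModel` now builds its EMA through the overridable `instantiate_ema_model` factory, `KDModel` only has to override that one method to get a `KDModelEMA`; the `ema_params` dict is unpacked straight into the factory. A hedged usage sketch (the keys mirror the factory's keyword arguments; the surrounding `training_params` structure is assumed, not shown in this diff):

```python
# Hypothetical training-params sketch: ema_params are forwarded as
# self.instantiate_ema_model(**ema_params), so keys must match its signature.
training_params = {
    "ema": True,
    "ema_params": {
        "decay": 0.9999,         # maximum decay value
        "beta": 15,              # how quickly the decay saturates
        "exp_activation": True,  # exponential schedule rather than a constant decay
    },
    # ... other training params
}
```

The same `ema_params` work unchanged for knowledge-distillation training, since only the factory method differs between `SgModel` and `KDModel`.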
@@ -6,7 +6,9 @@ from typing import Union
 import torch
 from torch import nn
 
+from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
+from super_gradients.training.models.kd_modules.kd_module import KDModule


 def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
@@ -49,8 +51,8 @@ class ModelEMA:
             self.decay_function = lambda x: decay  # always return the same decay factor
 
         """"
         """"
-        we hold a list of model attributes (not wights and biases) which we would like to include in each 
-        attribute update or exclude from each update. a SgModule declare these attribute using 
+        we hold a list of model attributes (not wights and biases) which we would like to include in each
+        attribute update or exclude from each update. a SgModule declare these attribute using
         get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
         all non-private (not starting with '_') attributes will be updated (and only them).
         """
         """
@@ -89,3 +91,39 @@ class ModelEMA:
         :param model: the source model
         """
         """
         copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)
+
+
+class KDModelEMA(ModelEMA):
+    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    This class is sensitive to where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+
+    def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
+        """
+        Init the EMA
+        :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by
+                    IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
+                    AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED (SEE
+                    YoLoV5Base IMPLEMENTATION IN super_gradients.trainer.models.yolov5.py AS AN EXAMPLE).
+        :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
+                      until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
+        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
+                     its final value. beta=15 is ~40% of the training process.
+        """
+        # Only work on the student (we don't want to update the teacher, nor keep a duplicate of it)
+        super().__init__(model=core_utils.WrappedModel(kd_model.module.student),
+                         decay=decay,
+                         beta=beta,
+                         exp_activation=exp_activation)
+
+        # Overwrite the current ema attribute with a combination of the student EMA (current self.ema)
+        # and the already-instantiated teacher, yielding the final KD EMA
+        self.ema = core_utils.WrappedModel(KDModule(arch_params=kd_model.module.arch_params,
+                                                    student=self.ema.module,
+                                                    teacher=kd_model.module.teacher,
+                                                    run_teacher_on_eval=kd_model.module.run_teacher_on_eval))
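The essence of `KDModelEMA` is that only the student receives an EMA copy, while the teacher module is shared by reference, so its weights are neither updated nor duplicated. A plain-PyTorch sketch of that pattern, with toy modules standing in for the real student/teacher (this illustrates the idea, not the library's actual implementation):

```python
from copy import deepcopy

import torch
import torch.nn as nn

# Hypothetical stand-ins for the student and teacher (the real ones are SgModules).
student, teacher = nn.Linear(10, 2), nn.Linear(10, 2)

# EMA is taken over the student only: deep-copy it once, then blend weights each step.
ema_student = deepcopy(student)
for p in ema_student.parameters():
    p.requires_grad_(False)

def ema_step(decay: float = 0.9999) -> None:
    """EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay), applied to the student only."""
    with torch.no_grad():
        for ema_p, p in zip(ema_student.parameters(), student.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)

ema_step()

# The "EMA KD model" pairs the EMA student with the *same* teacher object,
# so no duplicate teacher weights are created or kept in sync.
ema_kd_model = nn.ModuleDict({"student": ema_student, "teacher": teacher})
assert ema_kd_model["teacher"] is teacher
assert ema_kd_model["student"] is not student
```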