Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#211 SG-136: Apply ema only on student (KD)

Merged
Louis Dupont merged 1 commits into Deci-AI:master from deci-ai:feature/SG-136_use_ema_only_on_kd_student
@@ -6,7 +6,9 @@ from typing import Union
 import torch
 import torch
 from torch import nn
 from torch import nn
 
 
+from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
 from super_gradients.training.models import SgModule
+from super_gradients.training.models.kd_modules.kd_module import KDModule
 
 
 
 
 def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
 def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
@@ -49,8 +51,8 @@ class ModelEMA:
             self.decay_function = lambda x: decay  # always return the same decay factor
             self.decay_function = lambda x: decay  # always return the same decay factor
 
 
         """"
         """"
-        we hold a list of model attributes (not wights and biases) which we would like to include in each 
-        attribute update or exclude from each update. a SgModule declare these attribute using 
+        we hold a list of model attributes (not wights and biases) which we would like to include in each
+        attribute update or exclude from each update. a SgModule declare these attribute using
         get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
         get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
         all non-private (not starting with '_') attributes will be updated (and only them).
         all non-private (not starting with '_') attributes will be updated (and only them).
         """
         """
@@ -89,3 +91,39 @@ class ModelEMA:
         :param model: the source model
         :param model: the source model
         """
         """
         copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)
         copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)
+
+
+class KDModelEMA(ModelEMA):
+    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+
+    def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
+        """
+        Init the EMA
+        :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by
+                    IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
+                    AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED (SEE
+                    YoLoV5Base IMPLEMENTATION IN super_gradients.trainer.models.yolov5.py AS AN EXAMPLE).
+        :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
+                      until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
+        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
+                     its final value. beta=15 is ~40% of the training process.
+        """
+        # Only work on the student (we don't want to update and to have a duplicate of the teacher)
+        super().__init__(model=core_utils.WrappedModel(kd_model.module.student),
+                         decay=decay,
+                         beta=beta,
+                         exp_activation=exp_activation)
+
+        # Overwrite current ema attribute with combination of the student model EMA (current self.ema)
+        # with already the instantiated teacher, to have the final KD EMA
+        self.ema = core_utils.WrappedModel(KDModule(arch_params=kd_model.module.arch_params,
+                                                    student=self.ema.module,
+                                                    teacher=kd_model.module.teacher,
+                                                    run_teacher_on_eval=kd_model.module.run_teacher_on_eval))
Discard
Tip!

Press p or to see the previous file or, n or to see the next file