@@ -6,7 +6,9 @@ from typing import Union
 import torch
 from torch import nn
 
+from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
+from super_gradients.training.models.kd_modules.kd_module import KDModule
 
 
 def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
@@ -49,8 +51,8 @@ class ModelEMA:
             self.decay_function = lambda x: decay  # always return the same decay factor
 
         """"
-        we hold a list of model attributes (not wights and biases) which we would like to include in each
-        attribute update or exclude from each update. a SgModule declare these attribute using
+        we hold a list of model attributes (not weights and biases) which we would like to include in each
+        attribute update or exclude from each update. an SgModule declares these attributes using
         get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
         all non-private (not starting with '_') attributes will be updated (and only them).
         """
@@ -89,3 +91,39 @@ class ModelEMA:
         :param model: the source model
         """
         copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)
+
+
+class KDModelEMA(ModelEMA):
+    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    This class is sensitive to where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+
+    def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
+        """
+        Init the EMA
+        :param kd_model: KDModule, the training knowledge distillation model to construct the EMA model by.
+                         IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
+                         AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED (SEE
+                         YoLoV5Base IMPLEMENTATION IN super_gradients.trainer.models.yolov5.py AS AN EXAMPLE).
+        :param decay: the maximum decay value. As the training process advances, the decay will climb towards this value
+                      until EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay)
+        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
+                     its final value. beta=15 saturates at ~40% of the training process.
+        """
+        # Only work on the student (we don't want to update the teacher, nor to keep a duplicate of it)
+        super().__init__(model=core_utils.WrappedModel(kd_model.module.student),
+                         decay=decay,
+                         beta=beta,
+                         exp_activation=exp_activation)
+
+        # Overwrite the current ema attribute with a combination of the student EMA (the current self.ema)
+        # and the already-instantiated teacher, to form the final KD EMA
+        self.ema = core_utils.WrappedModel(KDModule(arch_params=kd_model.module.arch_params,
+                                                    student=self.ema.module,
+                                                    teacher=kd_model.module.teacher,
+                                                    run_teacher_on_eval=kd_model.module.run_teacher_on_eval))
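
For intuition on the decay and beta parameters documented above, the schedule can be sketched in a few lines. This is a minimal illustration, assuming the exp_activation ramp in ModelEMA has the form decay * (1 - exp(-x * beta)) with x the fraction of training completed; that expression is not part of this diff, so the snippet approximates the documented behaviour rather than reproducing the library code.

import math

def ema_decay(x: float, decay: float = 0.9999, beta: float = 15) -> float:
    # Decay used at training progress x in [0, 1]: ramps from 0 towards `decay`
    # (assumed form of ModelEMA's exp_activation schedule, not taken from this diff).
    return decay * (1 - math.exp(-x * beta))

# Each EMA step then blends the averaged and the training weights:
#     EMA_t+1 = EMA_t * d + TRAINING_MODEL * (1 - d),  with d = ema_decay(x)
for x in (0.0, 0.1, 0.4, 1.0):
    print(f"progress={x:.1f} -> decay={ema_decay(x):.4f}")

With beta=15 the decay already reaches roughly 99.7% of its final value at 40% of training, which is the saturation point the beta docstring refers to. KDModelEMA applies this average only to the student weights; the teacher is reused as-is inside the reconstructed KDModule.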