Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#211 SG-136: Apply ema only on student (KD)

Merged
Louis Dupont merged 1 commits into Deci-AI:master from deci-ai:feature/SG-136_use_ema_only_on_kd_student
@@ -823,7 +823,7 @@ class SgModel:
         if self.ema:
         if self.ema:
             ema_params = self.training_params.ema_params
             ema_params = self.training_params.ema_params
             logger.info(f'Using EMA with params {ema_params}')
             logger.info(f'Using EMA with params {ema_params}')
-            self.ema_model = ModelEMA(self.net, **ema_params)
+            self.ema_model = self.instantiate_ema_model(**ema_params)
             self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate
             self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate
             if self.load_checkpoint:
             if self.load_checkpoint:
                 if 'ema_net' in self.checkpoint.keys():
                 if 'ema_net' in self.checkpoint.keys():
@@ -1792,3 +1792,12 @@ class SgModel:
                 arch_params.num_classes = num_classes_new_head
                 arch_params.num_classes = num_classes_new_head
 
 
         return net
         return net
+
+    def instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> ModelEMA:
+        """Instantiate ema model for standard SgModule.
+        :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
+                      until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
+        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
+                     its final value. beta=15 is ~40% of the training process.
+        """
+        return ModelEMA(self.net, decay, beta, exp_activation)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file