Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#647 Feature/sg 573 Integrate new EMA decay schedules

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-573-Integrate-EMA
@@ -29,11 +29,21 @@ class EMAIntegrationTest(unittest.TestCase):
     def tearDownClass(cls) -> None:
     def tearDownClass(cls) -> None:
         pass
         pass
 
 
-    def test_train(self):
+    def test_train_exp_decay(self):
         self._init_model()
         self._init_model()
-        self._train({})
+        self._train({"decay_type": "exp", "beta": 15, "decay": 0.9999})
+
+    def test_train_threshold_decay(self):
+        self._init_model()
+        self._train({"decay_type": "threshold", "decay": 0.9999})
+
+    def test_train_constant_decay(self):
+        self._init_model()
+        self._train({"decay_type": "constant", "decay": 0.9999})
+
+    def test_train_with_old_ema_params(self):
         self._init_model()
         self._init_model()
-        self._train({"exp_activation": False})
+        self._train({"decay": 0.9999, "exp_activation": True, "beta": 10})
 
 
     def _train(self, ema_params):
     def _train(self, ema_params):
         training_params = {
         training_params = {
Discard
@@ -34,6 +34,7 @@ class KDEMATest(unittest.TestCase):
             "greater_metric_to_watch_is_better": True,
             "greater_metric_to_watch_is_better": True,
             "average_best_models": False,
             "average_best_models": False,
             "ema": True,
             "ema": True,
+            "ema_params": {"decay_type": "constant", "decay": 0.999},
         }
         }
 
 
     def test_teacher_ema_not_duplicated(self):
     def test_teacher_ema_not_duplicated(self):
Discard
@@ -133,6 +133,8 @@ class KDTrainerTest(unittest.TestCase):
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         train_params["ema"] = True
         train_params["ema"] = True
+        train_params["ema_params"] = {"decay_type": "constant", "decay": 0.999}
+
         kd_trainer.train(
         kd_trainer.train(
             training_params=train_params,
             training_params=train_params,
             student=student,
             student=student,
Discard