Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#647 Feature/sg 573 Integrate new EMA decay schedules

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-573-Integrate-EMA
2 changed files with 3 additions and 0 deletions
  1. 1
    0
      tests/unit_tests/kd_ema_test.py
  2. 2
    0
      tests/unit_tests/kd_trainer_test.py
@@ -34,6 +34,7 @@ class KDEMATest(unittest.TestCase):
             "greater_metric_to_watch_is_better": True,
             "greater_metric_to_watch_is_better": True,
             "average_best_models": False,
             "average_best_models": False,
             "ema": True,
             "ema": True,
+            "ema_params": {"decay_type": "constant", "decay": 0.999},
         }
         }
 
 
     def test_teacher_ema_not_duplicated(self):
     def test_teacher_ema_not_duplicated(self):
Discard
@@ -133,6 +133,8 @@ class KDTrainerTest(unittest.TestCase):
         train_params = self.kd_train_params.copy()
         train_params = self.kd_train_params.copy()
         train_params["max_epochs"] = 1
         train_params["max_epochs"] = 1
         train_params["ema"] = True
         train_params["ema"] = True
+        train_params["ema_params"] = {"decay_type": "constant", "decay": 0.999}
+
         kd_trainer.train(
         kd_trainer.train(
             training_params=train_params,
             training_params=train_params,
             student=student,
             student=student,
Discard