|
@@ -29,11 +29,21 @@ class EMAIntegrationTest(unittest.TestCase):
|
|
def tearDownClass(cls) -> None:
|
|
def tearDownClass(cls) -> None:
|
|
pass
|
|
pass
|
|
|
|
|
|
- def test_train(self):
|
|
|
|
|
|
+ def test_train_exp_decay(self):
|
|
self._init_model()
|
|
self._init_model()
|
|
- self._train({})
|
|
|
|
|
|
+ self._train({"decay_type": "exp", "beta": 15, "decay": 0.9999})
|
|
|
|
+
|
|
|
|
+ def test_train_threshold_decay(self):
|
|
|
|
+ self._init_model()
|
|
|
|
+ self._train({"decay_type": "threshold", "decay": 0.9999})
|
|
|
|
+
|
|
|
|
+ def test_train_constant_decay(self):
|
|
|
|
+ self._init_model()
|
|
|
|
+ self._train({"decay_type": "constant", "decay": 0.9999})
|
|
|
|
+
|
|
|
|
+ def test_train_with_old_ema_params(self):
|
|
self._init_model()
|
|
self._init_model()
|
|
- self._train({"exp_activation": False})
|
|
|
|
|
|
+ self._train({"decay": 0.9999, "exp_activation": True, "beta": 10})
|
|
|
|
|
|
def _train(self, ema_params):
|
|
def _train(self, ema_params):
|
|
training_params = {
|
|
training_params = {
|