|
@@ -133,6 +133,8 @@ class KDTrainerTest(unittest.TestCase):
|
|
train_params = self.kd_train_params.copy()
|
|
train_params = self.kd_train_params.copy()
|
|
train_params["max_epochs"] = 1
|
|
train_params["max_epochs"] = 1
|
|
train_params["ema"] = True
|
|
train_params["ema"] = True
|
|
|
|
+ train_params["ema_params"] = {"decay_type": "constant", "decay": 0.999}
|
|
|
|
+
|
|
kd_trainer.train(
|
|
kd_trainer.train(
|
|
training_params=train_params,
|
|
training_params=train_params,
|
|
student=student,
|
|
student=student,
|