Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#211 SG-136: Apply ema only on student (KD)

Merged
Louis Dupont merged 1 commits into Deci-AI:master from deci-ai:feature/SG-136_use_ema_only_on_kd_student
@@ -14,6 +14,7 @@ from super_gradients.training.exceptions.kd_model_exceptions import Architecture
     UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
     UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
     TeacherKnowledgeException, UndefinedNumClassesException
     TeacherKnowledgeException, UndefinedNumClassesException
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
+from super_gradients.training.utils.ema import KDModelEMA
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
@@ -247,3 +248,15 @@ class KDModel(SgModel):
                                    "teacher_arch_params": self.teacher_arch_params
                                    "teacher_arch_params": self.teacher_arch_params
                                    })
                                    })
         return hyper_param_config
         return hyper_param_config
+
+    def instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
+        """Instantiate KD ema model for KDModule.
+
+        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
+        :param decay:           the maximum decay value. as the training process advances, the decay will climb towards
+                                this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
+        :param beta:            the exponent coefficient. The higher the beta, the sooner in the training the decay will
+                                saturate to its final value. beta=15 is ~40% of the training process.
+        :param exp_activation:
+        """
+        return KDModelEMA(self.net, decay, beta, exp_activation)
Discard
@@ -823,7 +823,7 @@ class SgModel:
         if self.ema:
         if self.ema:
             ema_params = self.training_params.ema_params
             ema_params = self.training_params.ema_params
             logger.info(f'Using EMA with params {ema_params}')
             logger.info(f'Using EMA with params {ema_params}')
-            self.ema_model = ModelEMA(self.net, **ema_params)
+            self.ema_model = self.instantiate_ema_model(**ema_params)
             self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate
             self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate
             if self.load_checkpoint:
             if self.load_checkpoint:
                 if 'ema_net' in self.checkpoint.keys():
                 if 'ema_net' in self.checkpoint.keys():
@@ -1792,3 +1792,12 @@ class SgModel:
                 arch_params.num_classes = num_classes_new_head
                 arch_params.num_classes = num_classes_new_head
 
 
         return net
         return net
+
+    def instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> ModelEMA:
+        """Instantiate ema model for standard SgModule.
+        :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
+                      until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
+        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
+                     its final value. beta=15 is ~40% of the training process.
+        """
+        return ModelEMA(self.net, decay, beta, exp_activation)
Discard
@@ -6,7 +6,9 @@ from typing import Union
 import torch
 import torch
 from torch import nn
 from torch import nn
 
 
+from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
 from super_gradients.training.models import SgModule
+from super_gradients.training.models.kd_modules.kd_module import KDModule
 
 
 
 
 def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
 def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
@@ -49,8 +51,8 @@ class ModelEMA:
             self.decay_function = lambda x: decay  # always return the same decay factor
             self.decay_function = lambda x: decay  # always return the same decay factor
 
 
         """"
         """"
-        we hold a list of model attributes (not wights and biases) which we would like to include in each 
-        attribute update or exclude from each update. a SgModule declare these attribute using 
+        we hold a list of model attributes (not wights and biases) which we would like to include in each
+        attribute update or exclude from each update. a SgModule declare these attribute using
         get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
         get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
         all non-private (not starting with '_') attributes will be updated (and only them).
         all non-private (not starting with '_') attributes will be updated (and only them).
         """
         """
@@ -89,3 +91,39 @@ class ModelEMA:
         :param model: the source model
         :param model: the source model
         """
         """
         copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)
         copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)
+
+
+class KDModelEMA(ModelEMA):
+    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+
+    def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
+        """
+        Init the EMA
+        :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by
+                    IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
+                    AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED (SEE
+                    YoLoV5Base IMPLEMENTATION IN super_gradients.trainer.models.yolov5.py AS AN EXAMPLE).
+        :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
+                      until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
+        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
+                     its final value. beta=15 is ~40% of the training process.
+        """
+        # Only work on the student (we don't want to update and to have a duplicate of the teacher)
+        super().__init__(model=core_utils.WrappedModel(kd_model.module.student),
+                         decay=decay,
+                         beta=beta,
+                         exp_activation=exp_activation)
+
+        # Overwrite current ema attribute with combination of the student model EMA (current self.ema)
+        # with already the instantiated teacher, to have the final KD EMA
+        self.ema = core_utils.WrappedModel(KDModule(arch_params=kd_model.module.arch_params,
+                                                    student=self.ema.module,
+                                                    teacher=kd_model.module.teacher,
+                                                    run_teacher_on_eval=kd_model.module.run_teacher_on_eval))
Discard
@@ -12,6 +12,7 @@ from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.train_with_intialized_param_args_test import TrainWithInitializedObjectsTest
 from tests.unit_tests.train_with_intialized_param_args_test import TrainWithInitializedObjectsTest
 from tests.unit_tests.pretrained_models_unit_test import PretrainedModelsUnitTest
 from tests.unit_tests.pretrained_models_unit_test import PretrainedModelsUnitTest
 from tests.unit_tests.lr_warmup_test import LRWarmupTest
 from tests.unit_tests.lr_warmup_test import LRWarmupTest
+from tests.unit_tests.kd_ema_test import KDEMATest
 from tests.unit_tests.kd_model_test import KDModelTest
 from tests.unit_tests.kd_model_test import KDModelTest
 from tests.unit_tests.dice_loss_test import DiceLossTest
 from tests.unit_tests.dice_loss_test import DiceLossTest
 from tests.unit_tests.vit_unit_test import TestViT
 from tests.unit_tests.vit_unit_test import TestViT
@@ -54,6 +55,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(FactoriesTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(FactoriesTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DiceLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DiceLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestViT))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestViT))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDEMATest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDModelTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDModelTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLOX))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLOX))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(InitializeWithDataloadersTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(InitializeWithDataloadersTest))
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  1. import unittest
  2. from super_gradients.training.sg_model import SgModel
  3. from super_gradients.training.kd_model.kd_model import KDModel
  4. import torch
  5. from super_gradients.training.utils.utils import check_models_have_same_weights
  6. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
  7. from super_gradients.training.metrics import Accuracy
  8. from super_gradients.training.losses.kd_losses import KDLogitsLoss
  9. class KDEMATest(unittest.TestCase):
  10. @classmethod
  11. def setUp(cls):
  12. cls.sg_trained_teacher = SgModel("sg_trained_teacher", device='cpu')
  13. cls.dataset_params = {"batch_size": 5}
  14. cls.dataset = ClassificationTestDatasetInterface(dataset_params=cls.dataset_params)
  15. cls.kd_train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
  16. "lr_warmup_epochs": 0, "initial_lr": 0.1,
  17. "loss": KDLogitsLoss(torch.nn.CrossEntropyLoss()),
  18. "optimizer": "SGD",
  19. "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  20. "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
  21. "metric_to_watch": "Accuracy",
  22. 'loss_logging_items_names': ["Loss", "Task Loss", "Distillation Loss"],
  23. "greater_metric_to_watch_is_better": True, "average_best_models": False,
  24. "ema": True}
  25. def test_teacher_ema_not_duplicated(self):
  26. """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
  27. kd_model = KDModel("test_teacher_ema_not_duplicated", device='cpu')
  28. kd_model.connect_dataset_interface(self.dataset)
  29. kd_model.build_model(student_architecture='resnet18',
  30. teacher_architecture='resnet50',
  31. student_arch_params={'num_classes': 1000},
  32. teacher_arch_params={'num_classes': 1000},
  33. checkpoint_params={'teacher_pretrained_weights': "imagenet"},
  34. run_teacher_on_eval=True, )
  35. kd_model.train(self.kd_train_params)
  36. self.assertTrue(kd_model.ema_model.ema.module.teacher is kd_model.net.module.teacher)
  37. self.assertTrue(kd_model.ema_model.ema.module.student is not kd_model.net.module.student)
  38. def test_kd_ckpt_reload_ema(self):
  39. """Check that the KD model load correctly from checkpoint when "load_ema_as_net=True"."""
  40. # Create a KD model and train it
  41. kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
  42. kd_model.connect_dataset_interface(self.dataset)
  43. kd_model.build_model(student_architecture='resnet18',
  44. teacher_architecture='resnet50',
  45. student_arch_params={'num_classes': 1000},
  46. teacher_arch_params={'num_classes': 1000},
  47. checkpoint_params={'teacher_pretrained_weights': "imagenet"},
  48. run_teacher_on_eval=True, )
  49. kd_model.train(self.kd_train_params)
  50. ema_model = kd_model.ema_model.ema
  51. net = kd_model.net
  52. # Load the trained KD model
  53. kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
  54. kd_model.connect_dataset_interface(self.dataset)
  55. kd_model.build_model(student_architecture='resnet18',
  56. teacher_architecture='resnet50',
  57. student_arch_params={'num_classes': 1000},
  58. teacher_arch_params={'num_classes': 1000},
  59. checkpoint_params={"load_checkpoint": True, "load_ema_as_net": True},
  60. run_teacher_on_eval=True, )
  61. # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
  62. kd_model.train(self.kd_train_params)
  63. reloaded_ema_model = kd_model.ema_model.ema
  64. reloaded_net = kd_model.net
  65. # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
  66. self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
  67. # loaded net != trained net (since load_ema_as_net = True)
  68. self.assertTrue(not check_models_have_same_weights(reloaded_net, net))
  69. # loaded net == trained ema (since load_ema_as_net = True)
  70. self.assertTrue(check_models_have_same_weights(reloaded_net, ema_model))
  71. # loaded student ema == loaded student net (since load_ema_as_net = True)
  72. self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.student, reloaded_net.module.student))
  73. # loaded teacher ema == loaded teacher net (teacher always loads ema)
  74. self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher, reloaded_net.module.teacher))
  75. def test_kd_ckpt_reload_net(self):
  76. """Check that the KD model load correctly from checkpoint when "load_ema_as_net=False"."""
  77. # Create a KD model and train it
  78. kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
  79. kd_model.connect_dataset_interface(self.dataset)
  80. kd_model.build_model(student_architecture='resnet18',
  81. teacher_architecture='resnet50',
  82. student_arch_params={'num_classes': 1000},
  83. teacher_arch_params={'num_classes': 1000},
  84. checkpoint_params={'teacher_pretrained_weights': "imagenet"},
  85. run_teacher_on_eval=True, )
  86. kd_model.train(self.kd_train_params)
  87. ema_model = kd_model.ema_model.ema
  88. net = kd_model.net
  89. # Load the trained KD model
  90. kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
  91. kd_model.connect_dataset_interface(self.dataset)
  92. kd_model.build_model(student_architecture='resnet18',
  93. teacher_architecture='resnet50',
  94. student_arch_params={'num_classes': 1000},
  95. teacher_arch_params={'num_classes': 1000},
  96. checkpoint_params={"load_checkpoint": True, "load_ema_as_net": False},
  97. run_teacher_on_eval=True, )
  98. # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
  99. kd_model.train(self.kd_train_params)
  100. reloaded_ema_model = kd_model.ema_model.ema
  101. reloaded_net = kd_model.net
  102. # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
  103. self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
  104. # loaded net == trained net (since load_ema_as_net = False)
  105. self.assertTrue(check_models_have_same_weights(reloaded_net, net))
  106. # loaded net != trained ema (since load_ema_as_net = False)
  107. self.assertTrue(not check_models_have_same_weights(reloaded_net, ema_model))
  108. # loaded student ema == loaded student net (since load_ema_as_net = False)
  109. self.assertTrue(not check_models_have_same_weights(reloaded_ema_model.module.student, reloaded_net.module.student))
  110. # loaded teacher ema == loaded teacher net (teacher always loads ema)
  111. self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher, reloaded_net.module.teacher))
  112. if __name__ == '__main__':
  113. unittest.main()
Discard