Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#211 SG-136: Apply ema only on student (KD)

Merged
GitHub User merged 1 commits into Deci-AI:master from deci-ai:feature/SG-136_use_ema_only_on_kd_student
2 changed files with 142 additions and 0 deletions
  1. 2
    0
      tests/deci_core_unit_test_suite_runner.py
  2. 140
    0
      tests/unit_tests/kd_ema_test.py
@@ -12,6 +12,7 @@ from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.train_with_intialized_param_args_test import TrainWithInitializedObjectsTest
 from tests.unit_tests.train_with_intialized_param_args_test import TrainWithInitializedObjectsTest
 from tests.unit_tests.pretrained_models_unit_test import PretrainedModelsUnitTest
 from tests.unit_tests.pretrained_models_unit_test import PretrainedModelsUnitTest
 from tests.unit_tests.lr_warmup_test import LRWarmupTest
 from tests.unit_tests.lr_warmup_test import LRWarmupTest
+from tests.unit_tests.kd_ema_test import KDEMATest
 from tests.unit_tests.kd_model_test import KDModelTest
 from tests.unit_tests.kd_model_test import KDModelTest
 from tests.unit_tests.dice_loss_test import DiceLossTest
 from tests.unit_tests.dice_loss_test import DiceLossTest
 from tests.unit_tests.vit_unit_test import TestViT
 from tests.unit_tests.vit_unit_test import TestViT
@@ -54,6 +55,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(FactoriesTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(FactoriesTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DiceLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DiceLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestViT))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestViT))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDEMATest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDModelTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDModelTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLOX))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLOX))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(InitializeWithDataloadersTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(InitializeWithDataloadersTest))
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  1. import unittest
  2. from super_gradients.training.sg_model import SgModel
  3. from super_gradients.training.kd_model.kd_model import KDModel
  4. import torch
  5. from super_gradients.training.utils.utils import check_models_have_same_weights
  6. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
  7. from super_gradients.training.metrics import Accuracy
  8. from super_gradients.training.losses.kd_losses import KDLogitsLoss
  9. class KDEMATest(unittest.TestCase):
  10. @classmethod
  11. def setUp(cls):
  12. cls.sg_trained_teacher = SgModel("sg_trained_teacher", device='cpu')
  13. cls.dataset_params = {"batch_size": 5}
  14. cls.dataset = ClassificationTestDatasetInterface(dataset_params=cls.dataset_params)
  15. cls.kd_train_params = {"max_epochs": 3, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
  16. "lr_warmup_epochs": 0, "initial_lr": 0.1,
  17. "loss": KDLogitsLoss(torch.nn.CrossEntropyLoss()),
  18. "optimizer": "SGD",
  19. "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  20. "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
  21. "metric_to_watch": "Accuracy",
  22. 'loss_logging_items_names': ["Loss", "Task Loss", "Distillation Loss"],
  23. "greater_metric_to_watch_is_better": True, "average_best_models": False,
  24. "ema": True}
  25. def test_teacher_ema_not_duplicated(self):
  26. """Check that the teacher EMA is a reference to the teacher net (not a copy)."""
  27. kd_model = KDModel("test_teacher_ema_not_duplicated", device='cpu')
  28. kd_model.connect_dataset_interface(self.dataset)
  29. kd_model.build_model(student_architecture='resnet18',
  30. teacher_architecture='resnet50',
  31. student_arch_params={'num_classes': 1000},
  32. teacher_arch_params={'num_classes': 1000},
  33. checkpoint_params={'teacher_pretrained_weights': "imagenet"},
  34. run_teacher_on_eval=True, )
  35. kd_model.train(self.kd_train_params)
  36. self.assertTrue(kd_model.ema_model.ema.module.teacher is kd_model.net.module.teacher)
  37. self.assertTrue(kd_model.ema_model.ema.module.student is not kd_model.net.module.student)
  38. def test_kd_ckpt_reload_ema(self):
  39. """Check that the KD model load correctly from checkpoint when "load_ema_as_net=True"."""
  40. # Create a KD model and train it
  41. kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
  42. kd_model.connect_dataset_interface(self.dataset)
  43. kd_model.build_model(student_architecture='resnet18',
  44. teacher_architecture='resnet50',
  45. student_arch_params={'num_classes': 1000},
  46. teacher_arch_params={'num_classes': 1000},
  47. checkpoint_params={'teacher_pretrained_weights': "imagenet"},
  48. run_teacher_on_eval=True, )
  49. kd_model.train(self.kd_train_params)
  50. ema_model = kd_model.ema_model.ema
  51. net = kd_model.net
  52. # Load the trained KD model
  53. kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
  54. kd_model.connect_dataset_interface(self.dataset)
  55. kd_model.build_model(student_architecture='resnet18',
  56. teacher_architecture='resnet50',
  57. student_arch_params={'num_classes': 1000},
  58. teacher_arch_params={'num_classes': 1000},
  59. checkpoint_params={"load_checkpoint": True, "load_ema_as_net": True},
  60. run_teacher_on_eval=True, )
  61. # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
  62. kd_model.train(self.kd_train_params)
  63. reloaded_ema_model = kd_model.ema_model.ema
  64. reloaded_net = kd_model.net
  65. # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
  66. self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
  67. # loaded net != trained net (since load_ema_as_net = True)
  68. self.assertTrue(not check_models_have_same_weights(reloaded_net, net))
  69. # loaded net == trained ema (since load_ema_as_net = True)
  70. self.assertTrue(check_models_have_same_weights(reloaded_net, ema_model))
  71. # loaded student ema == loaded student net (since load_ema_as_net = True)
  72. self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.student, reloaded_net.module.student))
  73. # loaded teacher ema == loaded teacher net (teacher always loads ema)
  74. self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher, reloaded_net.module.teacher))
  75. def test_kd_ckpt_reload_net(self):
  76. """Check that the KD model load correctly from checkpoint when "load_ema_as_net=False"."""
  77. # Create a KD model and train it
  78. kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
  79. kd_model.connect_dataset_interface(self.dataset)
  80. kd_model.build_model(student_architecture='resnet18',
  81. teacher_architecture='resnet50',
  82. student_arch_params={'num_classes': 1000},
  83. teacher_arch_params={'num_classes': 1000},
  84. checkpoint_params={'teacher_pretrained_weights': "imagenet"},
  85. run_teacher_on_eval=True, )
  86. kd_model.train(self.kd_train_params)
  87. ema_model = kd_model.ema_model.ema
  88. net = kd_model.net
  89. # Load the trained KD model
  90. kd_model = KDModel("test_kd_ema_ckpt_reload", device='cpu')
  91. kd_model.connect_dataset_interface(self.dataset)
  92. kd_model.build_model(student_architecture='resnet18',
  93. teacher_architecture='resnet50',
  94. student_arch_params={'num_classes': 1000},
  95. teacher_arch_params={'num_classes': 1000},
  96. checkpoint_params={"load_checkpoint": True, "load_ema_as_net": False},
  97. run_teacher_on_eval=True, )
  98. # TRAIN FOR 0 EPOCHS JUST TO SEE THAT WHEN CONTINUING TRAINING EMA MODEL HAS BEEN SAVED CORRECTLY
  99. kd_model.train(self.kd_train_params)
  100. reloaded_ema_model = kd_model.ema_model.ema
  101. reloaded_net = kd_model.net
  102. # trained ema == loaded ema (Should always be true as long as "ema=True" in train_params)
  103. self.assertTrue(check_models_have_same_weights(ema_model, reloaded_ema_model))
  104. # loaded net == trained net (since load_ema_as_net = False)
  105. self.assertTrue(check_models_have_same_weights(reloaded_net, net))
  106. # loaded net != trained ema (since load_ema_as_net = False)
  107. self.assertTrue(not check_models_have_same_weights(reloaded_net, ema_model))
  108. # loaded student ema == loaded student net (since load_ema_as_net = False)
  109. self.assertTrue(not check_models_have_same_weights(reloaded_ema_model.module.student, reloaded_net.module.student))
  110. # loaded teacher ema == loaded teacher net (teacher always loads ema)
  111. self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher, reloaded_net.module.teacher))
  112. if __name__ == '__main__':
  113. unittest.main()
Discard