#578 Feature/sg 516 support head replacement for local pretrained weights unknown dataset

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-516_support_head_replacement_for_local_pretrained_weights_unknown_dataset
@@ -139,6 +139,7 @@ def get(
     pretrained_weights: str = None,
     load_backbone: bool = False,
     download_required_code: bool = True,
+    checkpoint_num_classes: int = None,
 ) -> SgModule:
     """
     :param model_name:          Defines the model's architecture from models/ALL_ARCHITECTURES
@@ -153,11 +154,20 @@ def get(
     :param load_backbone:       Load the provided checkpoint to model.backbone instead of model.
     :param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
                                     will prevent additional code from being downloaded. This affects only models from remote client.
+    :param checkpoint_num_classes:  num_classes of checkpoint_path / pretrained_weights, when checkpoint_path is not None.
+        Used when num_classes != checkpoint_num_classes. In this case, the module will be initialized with checkpoint_num_classes,
+        then the weights will be loaded. Finally, replace_head(new_num_classes=num_classes) is called (useful for transfer learning
+        from a checkpoint outside of the ones offered in the SG model zoo).
+

     NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
     """
+    checkpoint_num_classes = checkpoint_num_classes or num_classes

-    net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights, download_required_code)
+    if checkpoint_num_classes:
+        net = instantiate_model(model_name, arch_params, checkpoint_num_classes, pretrained_weights, download_required_code)
+    else:
+        net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights, download_required_code)

     if load_backbone and not checkpoint_path:
         raise ValueError("Please set checkpoint_path when load_backbone=True")
@@ -172,4 +182,7 @@ def get(
             load_weights_only=True,
             load_ema_as_net=load_ema_as_net,
         )
+    if checkpoint_num_classes != num_classes:
+        net.replace_head(new_num_classes=num_classes)
+
     return net
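
For context, a minimal usage sketch of the new parameter (the checkpoint path below is hypothetical; the class counts mirror the unit test added in this PR):

from super_gradients.training import models

# A local checkpoint was trained with 5 classes, but the new task has 10.
# models.get builds the architecture with checkpoint_num_classes=5, loads the
# local weights, then calls replace_head(new_num_classes=10) so only the head
# is re-initialized for the new task.
net = models.get(
    "resnet18",
    num_classes=10,                  # number of classes for the new task
    checkpoint_num_classes=5,        # num_classes the local checkpoint was trained with
    checkpoint_path="/path/to/ckpt_latest.pth",  # hypothetical local checkpoint path
)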
@@ -22,6 +22,7 @@ from tests.unit_tests import (
     CrashTipTest,
 )
 from tests.end_to_end_tests import TestTrainer
+from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
 from tests.unit_tests.quantization_utility_tests import QuantizationUtilityTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
@@ -107,6 +108,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(CallTrainAfterTestTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionOutputAdapter))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ConfigInspectTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LocalCkptHeadReplacementTest))

     def _add_modules_to_end_to_end_tests_suite(self):
         """
import os
import unittest

from super_gradients.training import Trainer, models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.utils.utils import check_models_have_same_weights


class LocalCkptHeadReplacementTest(unittest.TestCase):
    def test_local_ckpt_head_replacement(self):
        train_params = {
            "max_epochs": 1,
            "lr_updates": [1],
            "lr_decay_factor": 0.1,
            "lr_mode": "step",
            "lr_warmup_epochs": 0,
            "initial_lr": 0.1,
            "loss": "cross_entropy",
            "optimizer": "SGD",
            "criterion_params": {},
            "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
            "train_metrics_list": [Accuracy(), Top5()],
            "valid_metrics_list": [Accuracy(), Top5()],
            "metric_to_watch": "Accuracy",
            "greater_metric_to_watch_is_better": True,
        }

        # Define Model
        net = models.get("resnet18", num_classes=5)
        trainer = Trainer("test_resume_training")
        trainer.train(
            model=net,
            training_params=train_params,
            train_loader=classification_test_dataloader(),
            valid_loader=classification_test_dataloader(),
        )
        ckpt_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_latest.pth")

        net2 = models.get("resnet18", num_classes=10, checkpoint_num_classes=5, checkpoint_path=ckpt_path)
        self.assertFalse(check_models_have_same_weights(net, net2))

        net.linear = None
        net2.linear = None
        self.assertTrue(check_models_have_same_weights(net, net2))
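
To run only this test locally, a sketch following the standard unittest loader pattern (the module path is assumed from the import added in the suite runner diff above):

import unittest

from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest

if __name__ == "__main__":
    # Load and run just LocalCkptHeadReplacementTest with the standard unittest runner.
    suite = unittest.TestLoader().loadTestsFromTestCase(LocalCkptHeadReplacementTest)
    unittest.TextTestRunner(verbosity=2).run(suite)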