Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#578 Feature/sg 516 support head replacement for local pretrained weights unknown dataset

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-516_support_head_replacement_for_local_pretrained_weights_unknown_dataset
1 changed files with 14 additions and 1 deletions
  1. 14
    1
      src/super_gradients/training/models/model_factory.py
@@ -139,6 +139,7 @@ def get(
     pretrained_weights: str = None,
     pretrained_weights: str = None,
     load_backbone: bool = False,
     load_backbone: bool = False,
     download_required_code: bool = True,
     download_required_code: bool = True,
+    checkpoint_num_classes: int = None,
 ) -> SgModule:
 ) -> SgModule:
     """
     """
     :param model_name:          Defines the model's architecture from models/ALL_ARCHITECTURES
     :param model_name:          Defines the model's architecture from models/ALL_ARCHITECTURES
@@ -153,11 +154,20 @@ def get(
     :param load_backbone:       Load the provided checkpoint to model.backbone instead of model.
     :param load_backbone:       Load the provided checkpoint to model.backbone instead of model.
     :param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
     :param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
                                     will prevent additional code from being downloaded. This affects only models from remote client.
                                     will prevent additional code from being downloaded. This affects only models from remote client.
+    :param checkpoint_num_classes:  num_classes of checkpoint_path/ pretrained_weights, when checkpoint_path is not None.
+     Used when num_classes != checkpoint_num_class. In this case, the module will be initialized with checkpoint_num_class, then weights will be loaded. Finaly
+        replace_head(new_num_classes=num_classes) is called (useful when wanting to perform transfer learning, from a checkpoint outside of
+         then ones offered in SG model zoo).
+
 
 
     NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
     NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
     """
     """
+    checkpoint_num_classes = checkpoint_num_classes or num_classes
 
 
-    net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights, download_required_code)
+    if checkpoint_num_classes:
+        net = instantiate_model(model_name, arch_params, checkpoint_num_classes, pretrained_weights, download_required_code)
+    else:
+        net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights, download_required_code)
 
 
     if load_backbone and not checkpoint_path:
     if load_backbone and not checkpoint_path:
         raise ValueError("Please set checkpoint_path when load_backbone=True")
         raise ValueError("Please set checkpoint_path when load_backbone=True")
@@ -172,4 +182,7 @@ def get(
             load_weights_only=True,
             load_weights_only=True,
             load_ema_as_net=load_ema_as_net,
             load_ema_as_net=load_ema_as_net,
         )
         )
+    if checkpoint_num_classes != num_classes:
+        net.replace_head(new_num_classes=num_classes)
+
     return net
     return net
Discard