#381 Feature/sg 000 connect to lab

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/sg-000_connect_to_lab
1 changed file with 33 additions and 5 deletions
  1. src/super_gradients/training/models/model_factory.py (+33, -5)
@@ -1,10 +1,18 @@
+import hydra
+
 from super_gradients.common import StrictLoad
+from super_gradients.common.plugins.deci_client import DeciClient
 from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model, load_pretrained_weights, \
-    read_ckpt_state_dict
+from super_gradients.training.utils import HpmStruct
+from super_gradients.training.utils.checkpoint_utils import (
+    load_checkpoint_to_model,
+    load_pretrained_weights,
+    read_ckpt_state_dict,
+    load_pretrained_weights_local,
+)
 from super_gradients.common.abstractions.abstract_logger import get_logger
 
 logger = get_logger(__name__)
@@ -31,14 +39,34 @@ def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = No
 
         arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
 
+    remote_model = False
     if isinstance(name, str) and name in ARCHITECTURES.keys():
         architecture_cls = ARCHITECTURES[name]
         net = architecture_cls(arch_params=arch_params)
+    elif isinstance(name, str):
+        logger.info(f'Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab')
+        deci_client = DeciClient()
+        _arch_params = deci_client.get_model_arch_params(name)
+
+        if _arch_params is not None:
+            _arch_params = hydra.utils.instantiate(_arch_params)
+            base_name = _arch_params["model_name"]
+            _arch_params = HpmStruct(**_arch_params)
+            architecture_cls = ARCHITECTURES[base_name]
+            _arch_params.override(**arch_params.to_dict())
+
+            net = architecture_cls(arch_params=_arch_params)
+            remote_model = True
+        else:
+            raise ValueError("Unsupported model name " + str(name) + ", see docs or all_architectures.py for supported nets.")
     else:
-        raise ValueError(
-            "Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
+        raise ValueError("Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
     if pretrained_weights:
-        load_pretrained_weights(net, name, pretrained_weights)
+        if remote_model:
+            weights_path = deci_client.get_model_weights(name)
+            load_pretrained_weights_local(net, name, weights_path)
+        else:
+            load_pretrained_weights(net, name, pretrained_weights)
         if num_classes_new_head != arch_params.num_classes:
             net.replace_head(new_num_classes=num_classes_new_head)
             arch_params.num_classes = num_classes_new_head
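
For context (not part of the PR): a rough usage sketch of how the new fallback is reached. "some_lab_model" is a hypothetical name assumed to exist only in the remote Deci lab; a configured platform connection and whatever arch params the chosen architecture expects are also assumed.

from super_gradients.training.models.model_factory import instantiate_model
from super_gradients.training.utils import HpmStruct

# A name present in ARCHITECTURES still resolves locally, exactly as before.
local_net = instantiate_model(name="resnet18", arch_params=HpmStruct(num_classes=10))

# An unknown name now falls through to DeciClient: the lab's arch params are fetched
# and merged with the caller's arch_params, and (when pretrained_weights is set) the
# weights file is downloaded and loaded via load_pretrained_weights_local.
remote_net = instantiate_model(name="some_lab_model", arch_params=HpmStruct(), pretrained_weights=None)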
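
And a minimal sketch of the arch-params merge performed in the remote branch above (values are hypothetical; it only assumes HpmStruct's behavior as used in the diff: keyword fields become attributes, override() applies caller values on top, to_dict() round-trips them):

from super_gradients.training.utils import HpmStruct

lab_arch_params = {"model_name": "resnet18", "num_classes": 1000}  # shape of what the lab might return
caller_arch_params = HpmStruct(num_classes=10)                     # overrides passed by the caller

merged = HpmStruct(**lab_arch_params)
merged.override(**caller_arch_params.to_dict())
# merged.num_classes -> 10, merged.model_name -> "resnet18"; the merged struct is then
# passed as arch_params to the architecture class looked up by the lab's model_name.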