Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#381 Feature/sg 000 connect to lab

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/sg-000_connect_to_lab
@@ -1,10 +1,18 @@
+import hydra
+
 from super_gradients.common import StrictLoad
 from super_gradients.common import StrictLoad
+from super_gradients.common.plugins.deci_client import DeciClient
 from super_gradients.training import utils as core_utils
 from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model, load_pretrained_weights, \
-    read_ckpt_state_dict
+from super_gradients.training.utils import HpmStruct
+from super_gradients.training.utils.checkpoint_utils import (
+    load_checkpoint_to_model,
+    load_pretrained_weights,
+    read_ckpt_state_dict,
+    load_pretrained_weights_local,
+)
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
@@ -31,14 +39,34 @@ def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = No
 
 
         arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
         arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
 
 
+    remote_model = False
     if isinstance(name, str) and name in ARCHITECTURES.keys():
     if isinstance(name, str) and name in ARCHITECTURES.keys():
         architecture_cls = ARCHITECTURES[name]
         architecture_cls = ARCHITECTURES[name]
         net = architecture_cls(arch_params=arch_params)
         net = architecture_cls(arch_params=arch_params)
+    elif isinstance(name, str):
+        logger.info(f'Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab')
+        deci_client = DeciClient()
+        _arch_params = deci_client.get_model_arch_params(name)
+
+        if _arch_params is not None:
+            _arch_params = hydra.utils.instantiate(_arch_params)
+            base_name = _arch_params["model_name"]
+            _arch_params = HpmStruct(**_arch_params)
+            architecture_cls = ARCHITECTURES[base_name]
+            _arch_params.override(**arch_params.to_dict())
+
+            net = architecture_cls(arch_params=_arch_params)
+            remote_model = True
+        else:
+            raise ValueError("Unsupported model name " + str(name) + ", see docs or all_architectures.py for supported nets.")
     else:
     else:
-        raise ValueError(
-            "Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
+        raise ValueError("Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
     if pretrained_weights:
     if pretrained_weights:
-        load_pretrained_weights(net, name, pretrained_weights)
+        if remote_model:
+            weights_path = deci_client.get_model_weights(name)
+            load_pretrained_weights_local(net, name, weights_path)
+        else:
+            load_pretrained_weights(net, name, pretrained_weights)
         if num_classes_new_head != arch_params.num_classes:
         if num_classes_new_head != arch_params.num_classes:
             net.replace_head(new_num_classes=num_classes_new_head)
             net.replace_head(new_num_classes=num_classes_new_head)
             arch_params.num_classes = num_classes_new_head
             arch_params.num_classes = num_classes_new_head
Discard
@@ -265,6 +265,10 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
     unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace('/', '_').replace(' ', '_')
     unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace('/', '_').replace(' ', '_')
     map_location = torch.device('cpu')
     map_location = torch.device('cpu')
     pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
     pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
+    _load_weights(architecture, model, pretrained_state_dict)
+
+
+def _load_weights(architecture, model, pretrained_state_dict):
     if 'ema_net' in pretrained_state_dict.keys():
     if 'ema_net' in pretrained_state_dict.keys():
         pretrained_state_dict['net'] = pretrained_state_dict['ema_net']
         pretrained_state_dict['net'] = pretrained_state_dict['ema_net']
     solver = _yolox_ckpt_solver if "yolox" in architecture else None
     solver = _yolox_ckpt_solver if "yolox" in architecture else None
@@ -272,3 +276,19 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
                                                                               source_ckpt=pretrained_state_dict,
                                                                               source_ckpt=pretrained_state_dict,
                                                                               solver=solver)
                                                                               solver=solver)
     model.load_state_dict(adapted_pretrained_state_dict['net'], strict=False)
     model.load_state_dict(adapted_pretrained_state_dict['net'], strict=False)
+
+
+def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
+
+    """
+    Loads pretrained weights from the MODEL_URLS dictionary to model
+    @param architecture: name of the model's architecture
+    @param model: model to load pretrinaed weights for
+    @param pretrained_weights: path tp pretrained weights
+    @return: None
+    """
+
+    map_location = torch.device('cpu')
+
+    pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
+    _load_weights(architecture, model, pretrained_state_dict)
Discard