Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#381 Feature/sg 000 connect to lab

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/sg-000_connect_to_lab
@@ -24,7 +24,6 @@ pyparsing==2.4.5
 einops==0.3.2
 einops==0.3.2
 pycocotools==2.0.4
 pycocotools==2.0.4
 protobuf~=3.19.0
 protobuf~=3.19.0
-deci-lab-client==2.38.0
 treelib==1.6.1
 treelib==1.6.1
 termcolor==1.1.0
 termcolor==1.1.0
 packaging>=20.4
 packaging>=20.4
Discard
@@ -1,4 +1,5 @@
 import argparse
 import argparse
+import importlib
 import os
 import os
 import sys
 import sys
 import socket
 import socket
@@ -33,6 +34,18 @@ class ColouredTextFormatter:
         return print(''.join([colour, text, TerminalColours.ENDC]))
         return print(''.join([colour, text, TerminalColours.ENDC]))
 
 
 
 
def get_cls(cls_path):
    """
    A resolver for Hydra/OmegaConf to allow getting a class instead of an instance.

    Usage in a recipe yaml:
        class_of_optimizer: ${class:torch.optim.Adam}

    :param cls_path: fully qualified dotted path of the attribute, e.g. "torch.optim.Adam"
    :return: the attribute (typically a class) that the path points to
    :raises ImportError: if the module part of the path cannot be imported
    :raises AttributeError: if the module has no attribute with the given name
    """
    # Split "pkg.module.Name" into "pkg.module" and "Name" on the last dot only.
    module_path, _, name = cls_path.rpartition('.')
    # import_module already returns the module object - no need to go through sys.modules.
    module = importlib.import_module(module_path)
    return getattr(module, name)
+
 def get_environ_as_type(environment_variable_name: str, default=None, cast_to_type: type = str) -> object:
 def get_environ_as_type(environment_variable_name: str, default=None, cast_to_type: type = str) -> object:
     """
     """
     Tries to get an environment variable and cast it into a requested type.
     Tries to get an environment variable and cast it into a requested type.
@@ -65,19 +78,22 @@ def init_trainer():
     This function should be the first thing to be called by any code running super_gradients.
     This function should be the first thing to be called by any code running super_gradients.
     It resolves conflicts between the different tools, packages and environments used and prepares the super_gradients environment.
     It resolves conflicts between the different tools, packages and environments used and prepares the super_gradients environment.
     """
     """
+    if not environment_config.INIT_TRAINER:
 
 
-    register_hydra_resolvers()
+        register_hydra_resolvers()
 
 
-    # We pop local_rank if it was specified in the args, because it would break
-    args_local_rank = pop_arg("local_rank", default_value=-1)
+        # We pop local_rank if it was specified in the args, because it would break
+        args_local_rank = pop_arg("local_rank", default_value=-1)
 
 
-    # Set local_rank with priority order (env variable > args.local_rank > args.default_value)
-    environment_config.DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=args_local_rank))
+        # Set local_rank with priority order (env variable > args.local_rank > args.default_value)
+        environment_config.DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=args_local_rank))
+        environment_config.INIT_TRAINER = True
 
 
 
 
 def register_hydra_resolvers():
 def register_hydra_resolvers():
     """Register all the hydra resolvers required for the super-gradients recipes."""
     """Register all the hydra resolvers required for the super-gradients recipes."""
     OmegaConf.register_new_resolver("hydra_output_dir", hydra_output_dir_resolver, replace=True)
     OmegaConf.register_new_resolver("hydra_output_dir", hydra_output_dir_resolver, replace=True)
+    OmegaConf.register_new_resolver("class", lambda *args: get_cls(*args), replace=True)
 
 
 
 
 def pop_arg(arg_name: str, default_value: int = None) -> argparse.Namespace:
 def pop_arg(arg_name: str, default_value: int = None) -> argparse.Namespace:
Discard
@@ -34,3 +34,5 @@ logging.basicConfig(
 )  # Set the default level for all libraries - including 3rd party packages
 )  # Set the default level for all libraries - including 3rd party packages
 
 
 DDP_LOCAL_RANK = -1
 DDP_LOCAL_RANK = -1
+
+INIT_TRAINER = False
Discard
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    1. import json
    2. import hydra
    3. import pkg_resources
    4. from hydra.core.global_hydra import GlobalHydra
    5. from omegaconf import DictConfig
    6. from super_gradients.common.abstractions.abstract_logger import get_logger
    7. logger = get_logger(__name__)
    8. client_enabled = True
    9. try:
    10. from deci_lab_client.client import DeciPlatformClient
    11. from deci_common.data_interfaces.files_data_interface import FilesDataInterface
    12. from deci_lab_client.models import AutoNACFileName
    13. from deci_lab_client import ApiException
    14. except (ImportError, NameError):
    15. client_enabled = False
    16. class DeciClient:
    17. """
    18. A client to deci platform and model zoo.
    19. requires credentials for connection
    20. """
    21. def __init__(self):
    22. if not client_enabled:
    23. logger.error('deci-lab-client or deci-common are not installed. Model cannot be loaded from deci lab.'
    24. 'Please install deci-lab-client>=2.55.0 and deci-common>=3.4.1')
    25. return
    26. self.lab_client = DeciPlatformClient()
    27. GlobalHydra.instance().clear()
    28. self.super_gradients_version = None
    29. try:
    30. self.super_gradients_version = pkg_resources.get_distribution("super_gradients").version
    31. except pkg_resources.DistributionNotFound:
    32. self.super_gradients_version = "3.0.0"
    33. def _get_file(self, model_name: str, file_name: str) -> str:
    34. try:
    35. response = self.lab_client.get_autonac_model_file_link(
    36. model_name=model_name, file_name=file_name, super_gradients_version=self.super_gradients_version
    37. )
    38. download_link = response.data
    39. except ApiException as e:
    40. if e.status == 401:
    41. logger.error("Unauthorized. wrong token or token was not defined. please login to deci-lab-client "
    42. "by calling DeciPlatformClient().login(<token>)")
    43. elif e.status == 400 and e.body is not None and "message" in e.body:
    44. logger.error(f"Deci client: {json.loads(e.body)['message']}")
    45. else:
    46. logger.error(e.body)
    47. return None
    48. return FilesDataInterface.download_temporary_file(file_url=download_link)
    49. def _get_model_cfg(self, model_name: str, cfg_file_name: str) -> DictConfig:
    50. if not client_enabled:
    51. return None
    52. file = self._get_file(model_name=model_name, file_name=cfg_file_name)
    53. if file is None:
    54. return None
    55. split_file = file.split("/")
    56. with hydra.initialize_config_dir(config_dir=f"{'/'.join(split_file[:-1])}/", version_base=None):
    57. cfg = hydra.compose(config_name=split_file[-1])
    58. return cfg
    59. def get_model_arch_params(self, model_name: str) -> DictConfig:
    60. return self.get_model_cfg(model_name, AutoNACFileName.STRUCTURE_YAML)
    61. def get_model_recipe(self, model_name: str) -> DictConfig:
    62. return self.get_model_cfg(model_name, AutoNACFileName.RECIPE_YAML)
    63. def get_model_weights(self, model_name: str) -> str:
    64. if not client_enabled:
    65. return None
    66. return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH)
    Discard
    @@ -1,10 +1,18 @@
    +import hydra
    +
     from super_gradients.common import StrictLoad
     from super_gradients.common import StrictLoad
    +from super_gradients.common.plugins.deci_client import DeciClient
     from super_gradients.training import utils as core_utils
     from super_gradients.training import utils as core_utils
     from super_gradients.training.models import SgModule
     from super_gradients.training.models import SgModule
     from super_gradients.training.models.all_architectures import ARCHITECTURES
     from super_gradients.training.models.all_architectures import ARCHITECTURES
     from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
     from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
    -from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model, load_pretrained_weights, \
    -    read_ckpt_state_dict
    +from super_gradients.training.utils import HpmStruct
    +from super_gradients.training.utils.checkpoint_utils import (
    +    load_checkpoint_to_model,
    +    load_pretrained_weights,
    +    read_ckpt_state_dict,
    +    load_pretrained_weights_local,
    +)
     from super_gradients.common.abstractions.abstract_logger import get_logger
     from super_gradients.common.abstractions.abstract_logger import get_logger
     
     
     logger = get_logger(__name__)
     logger = get_logger(__name__)
    @@ -31,14 +39,34 @@ def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = No
     
     
             arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
             arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
     
     
    +    remote_model = False
         if isinstance(name, str) and name in ARCHITECTURES.keys():
         if isinstance(name, str) and name in ARCHITECTURES.keys():
             architecture_cls = ARCHITECTURES[name]
             architecture_cls = ARCHITECTURES[name]
             net = architecture_cls(arch_params=arch_params)
             net = architecture_cls(arch_params=arch_params)
    +    elif isinstance(name, str):
    +        logger.info(f'Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab')
    +        deci_client = DeciClient()
    +        _arch_params = deci_client.get_model_arch_params(name)
    +
    +        if _arch_params is not None:
    +            _arch_params = hydra.utils.instantiate(_arch_params)
    +            base_name = _arch_params["model_name"]
    +            _arch_params = HpmStruct(**_arch_params)
    +            architecture_cls = ARCHITECTURES[base_name]
    +            _arch_params.override(**arch_params.to_dict())
    +
    +            net = architecture_cls(arch_params=_arch_params)
    +            remote_model = True
    +        else:
    +            raise ValueError("Unsupported model name " + str(name) + ", see docs or all_architectures.py for supported nets.")
         else:
         else:
    -        raise ValueError(
    -            "Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
    +        raise ValueError("Unsupported model model_name " + str(name) + ", see docs or all_architectures.py for supported nets.")
         if pretrained_weights:
         if pretrained_weights:
    -        load_pretrained_weights(net, name, pretrained_weights)
    +        if remote_model:
    +            weights_path = deci_client.get_model_weights(name)
    +            load_pretrained_weights_local(net, name, weights_path)
    +        else:
    +            load_pretrained_weights(net, name, pretrained_weights)
             if num_classes_new_head != arch_params.num_classes:
             if num_classes_new_head != arch_params.num_classes:
                 net.replace_head(new_num_classes=num_classes_new_head)
                 net.replace_head(new_num_classes=num_classes_new_head)
                 arch_params.num_classes = num_classes_new_head
                 arch_params.num_classes = num_classes_new_head
    Discard
    @@ -265,6 +265,10 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
         unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace('/', '_').replace(' ', '_')
         unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace('/', '_').replace(' ', '_')
         map_location = torch.device('cpu')
         map_location = torch.device('cpu')
         pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
         pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
    +    _load_weights(architecture, model, pretrained_state_dict)
    +
    +
    +def _load_weights(architecture, model, pretrained_state_dict):
         if 'ema_net' in pretrained_state_dict.keys():
         if 'ema_net' in pretrained_state_dict.keys():
             pretrained_state_dict['net'] = pretrained_state_dict['ema_net']
             pretrained_state_dict['net'] = pretrained_state_dict['ema_net']
         solver = _yolox_ckpt_solver if "yolox" in architecture else None
         solver = _yolox_ckpt_solver if "yolox" in architecture else None
    @@ -272,3 +276,19 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
                                                                                   source_ckpt=pretrained_state_dict,
                                                                                   source_ckpt=pretrained_state_dict,
                                                                                   solver=solver)
                                                                                   solver=solver)
         model.load_state_dict(adapted_pretrained_state_dict['net'], strict=False)
         model.load_state_dict(adapted_pretrained_state_dict['net'], strict=False)
    +
    +
    +def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
    +
    +    """
    +    Loads pretrained weights from the MODEL_URLS dictionary to model
    +    @param architecture: name of the model's architecture
    +    @param model: model to load pretrinaed weights for
    +    @param pretrained_weights: path tp pretrained weights
    +    @return: None
    +    """
    +
    +    map_location = torch.device('cpu')
    +
    +    pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
    +    _load_weights(architecture, model, pretrained_state_dict)
    Discard