Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#381 Feature/sg 000 connect to lab

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/sg-000_connect_to_lab
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    1. import json
    2. import hydra
    3. import pkg_resources
    4. from hydra.core.global_hydra import GlobalHydra
    5. from omegaconf import DictConfig
    6. from super_gradients.common.abstractions.abstract_logger import get_logger
    7. logger = get_logger(__name__)
    8. client_enabled = True
    9. try:
    10. from deci_lab_client.client import DeciPlatformClient
    11. from deci_common.data_interfaces.files_data_interface import FilesDataInterface
    12. from deci_lab_client.models import AutoNACFileName
    13. from deci_lab_client import ApiException
    14. except (ImportError, NameError):
    15. client_enabled = False
    16. class DeciClient:
    17. """
    18. A client to deci platform and model zoo.
    19. requires credentials for connection
    20. """
    21. def __init__(self):
    22. if not client_enabled:
    23. logger.error('deci-lab-client or deci-common are not installed. Model cannot be loaded from deci lab.'
    24. 'Please install deci-lab-client>=2.55.0 and deci-common>=3.4.1')
    25. return
    26. self.lab_client = DeciPlatformClient()
    27. GlobalHydra.instance().clear()
    28. self.super_gradients_version = None
    29. try:
    30. self.super_gradients_version = pkg_resources.get_distribution("super_gradients").version
    31. except pkg_resources.DistributionNotFound:
    32. self.super_gradients_version = "3.0.0"
    33. def _get_file(self, model_name: str, file_name: str) -> str:
    34. try:
    35. response = self.lab_client.get_autonac_model_file_link(
    36. model_name=model_name, file_name=file_name, super_gradients_version=self.super_gradients_version
    37. )
    38. download_link = response.data
    39. except ApiException as e:
    40. if e.status == 401:
    41. logger.error("Unauthorized. wrong token or token was not defined. please login to deci-lab-client "
    42. "by calling DeciPlatformClient().login(<token>)")
    43. elif e.status == 400 and e.body is not None and "message" in e.body:
    44. logger.error(f"Deci client: {json.loads(e.body)['message']}")
    45. else:
    46. logger.error(e.body)
    47. return None
    48. return FilesDataInterface.download_temporary_file(file_url=download_link)
    49. def _get_model_cfg(self, model_name: str, cfg_file_name: str) -> DictConfig:
    50. if not client_enabled:
    51. return None
    52. file = self._get_file(model_name=model_name, file_name=cfg_file_name)
    53. if file is None:
    54. return None
    55. split_file = file.split("/")
    56. with hydra.initialize_config_dir(config_dir=f"{'/'.join(split_file[:-1])}/", version_base=None):
    57. cfg = hydra.compose(config_name=split_file[-1])
    58. return cfg
    59. def get_model_arch_params(self, model_name: str) -> DictConfig:
    60. return self.get_model_cfg(model_name, AutoNACFileName.STRUCTURE_YAML)
    61. def get_model_recipe(self, model_name: str) -> DictConfig:
    62. return self.get_model_cfg(model_name, AutoNACFileName.RECIPE_YAML)
    63. def get_model_weights(self, model_name: str) -> str:
    64. if not client_enabled:
    65. return None
    66. return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH)
    Discard