Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#475 Feature/sg 000 clean start prints

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-000_clean_start_prints
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
  1. import json
  2. import sys
  3. from zipfile import ZipFile
  4. import hydra
  5. import importlib.util
  6. import os
  7. import pkg_resources
  8. from hydra.core.global_hydra import GlobalHydra
  9. from omegaconf import DictConfig
  10. from super_gradients.common.abstractions.abstract_logger import get_logger
  11. from super_gradients.training.utils.hydra_utils import normalize_path
  12. logger = get_logger(__name__)
  13. client_enabled = True
  14. try:
  15. from deci_lab_client.client import DeciPlatformClient
  16. from deci_common.data_interfaces.files_data_interface import FilesDataInterface
  17. from deci_lab_client.models import AutoNACFileName
  18. from deci_lab_client import ApiException
  19. except (ImportError, NameError):
  20. client_enabled = False
  21. class DeciClient:
  22. """
  23. A client to deci platform and model zoo.
  24. requires credentials for connection
  25. """
  26. def __init__(self):
  27. if not client_enabled:
  28. logger.error('deci-lab-client or deci-common are not installed. Model cannot be loaded from deci lab.'
  29. 'Please install deci-lab-client>=2.55.0 and deci-common>=3.4.1')
  30. return
  31. self.lab_client = DeciPlatformClient()
  32. GlobalHydra.instance().clear()
  33. self.super_gradients_version = None
  34. try:
  35. self.super_gradients_version = pkg_resources.get_distribution("super_gradients").version
  36. except pkg_resources.DistributionNotFound:
  37. self.super_gradients_version = "3.0.0"
  38. def _get_file(self, model_name: str, file_name: str) -> str:
  39. try:
  40. response = self.lab_client.get_autonac_model_file_link(
  41. model_name=model_name, file_name=file_name, super_gradients_version=self.super_gradients_version
  42. )
  43. download_link = response.data
  44. except ApiException as e:
  45. if e.status == 401:
  46. logger.error("Unauthorized. wrong token or token was not defined. please login to deci-lab-client "
  47. "by calling DeciPlatformClient().login(<token>)")
  48. elif e.status == 400 and e.body is not None and "message" in e.body:
  49. logger.error(f"Deci client: {json.loads(e.body)['message']}")
  50. else:
  51. logger.error(e.body)
  52. return None
  53. return FilesDataInterface.download_temporary_file(file_url=download_link)
  54. def _get_model_cfg(self, model_name: str, cfg_file_name: str) -> DictConfig:
  55. if not client_enabled:
  56. return None
  57. file = self._get_file(model_name=model_name, file_name=cfg_file_name)
  58. if file is None:
  59. return None
  60. split_file = file.split("/")
  61. with hydra.initialize_config_dir(config_dir=normalize_path(f"{'/'.join(split_file[:-1])}/"), version_base=None):
  62. cfg = hydra.compose(config_name=split_file[-1])
  63. return cfg
  64. def get_model_arch_params(self, model_name: str) -> DictConfig:
  65. return self._get_model_cfg(model_name, AutoNACFileName.STRUCTURE_YAML)
  66. def get_model_recipe(self, model_name: str) -> DictConfig:
  67. return self._get_model_cfg(model_name, AutoNACFileName.RECIPE_YAML)
  68. def get_model_weights(self, model_name: str) -> str:
  69. if not client_enabled:
  70. return None
  71. return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH)
  72. def download_and_load_model_additional_code(self, model_name: str, target_path: str, package_name: str = "deci_model_code") -> None:
  73. """
  74. try to download code files for this model.
  75. if found, code files will be placed in the target_path/package_name and imported dynamically
  76. """
  77. file = self._get_file(model_name=model_name, file_name=AutoNACFileName.CODE_ZIP)
  78. package_path = os.path.join(target_path, package_name)
  79. if file is not None:
  80. # crete the directory
  81. os.makedirs(package_path, exist_ok=True)
  82. # extract code files
  83. with ZipFile(file) as zipfile:
  84. zipfile.extractall(package_path)
  85. # add an init file that imports all code files
  86. with open(os.path.join(package_path, '__init__.py'), 'w') as init_file:
  87. all_str = '\n\n__all__ = ['
  88. for code_file in os.listdir(path=package_path):
  89. if code_file.endswith(".py") and not code_file.startswith("__init__"):
  90. init_file.write(f'import {code_file.replace(".py", "")}\n')
  91. all_str += f'"{code_file.replace(".py", "")}", '
  92. all_str += "]\n\n"
  93. init_file.write(all_str)
  94. # include in path and import
  95. sys.path.insert(1, package_path)
  96. importlib.import_module(package_name)
  97. logger.info(f'*** IMPORTANT ***: files required for the model {model_name} were downloaded and added to your code in:\n{package_path}\n'
  98. f'These files will be downloaded to the same location each time the model is fetched from the deci-client.\n'
  99. f'you can override this by passing models.get(... download_required_code=False) and importing the files yourself')
Discard
Tip!

Press p or to see the previous file or, n or to see the next file