Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
  1. import json
  2. import sys
  3. from zipfile import ZipFile
  4. import hydra
  5. import importlib.util
  6. import os
  7. import pkg_resources
  8. from hydra.core.global_hydra import GlobalHydra
  9. from omegaconf import DictConfig
  10. from torch import nn
  11. from super_gradients.common.abstractions.abstract_logger import get_logger
  12. from super_gradients.training.utils.hydra_utils import normalize_path
  13. logger = get_logger(__name__)
  14. client_enabled = True
  15. try:
  16. from deci_lab_client.client import DeciPlatformClient
  17. from deci_common.data_interfaces.files_data_interface import FilesDataInterface
  18. from deci_lab_client.models import AutoNACFileName
  19. from deci_lab_client import ApiException
  20. except (ImportError, NameError):
  21. client_enabled = False
  22. class DeciClient:
  23. """
  24. A client to deci platform and model zoo.
  25. requires credentials for connection
  26. """
  27. def __init__(self):
  28. if not client_enabled:
  29. logger.error(
  30. "deci-lab-client or deci-common are not installed. Model cannot be loaded from deci lab."
  31. "Please install deci-lab-client>=2.55.0 and deci-common>=3.4.1"
  32. )
  33. return
  34. self.lab_client = DeciPlatformClient()
  35. GlobalHydra.instance().clear()
  36. self.super_gradients_version = None
  37. try:
  38. self.super_gradients_version = pkg_resources.get_distribution("super_gradients").version
  39. except pkg_resources.DistributionNotFound:
  40. self.super_gradients_version = "3.0.2"
  41. def _get_file(self, model_name: str, file_name: str) -> str:
  42. try:
  43. response = self.lab_client.get_autonac_model_file_link(
  44. model_name=model_name, file_name=file_name, super_gradients_version=self.super_gradients_version
  45. )
  46. download_link = response.data
  47. except ApiException as e:
  48. if e.status == 401:
  49. logger.error(
  50. "Unauthorized. wrong token or token was not defined. please login to deci-lab-client " "by calling DeciPlatformClient().login(<token>)"
  51. )
  52. elif e.status == 400 and e.body is not None and "message" in e.body:
  53. logger.error(f"Deci client: {json.loads(e.body)['message']}")
  54. else:
  55. logger.debug(e.body)
  56. return None
  57. return FilesDataInterface.download_temporary_file(file_url=download_link)
  58. def _get_model_cfg(self, model_name: str, cfg_file_name: str) -> DictConfig:
  59. if not client_enabled:
  60. return None
  61. file = self._get_file(model_name=model_name, file_name=cfg_file_name)
  62. if file is None:
  63. return None
  64. split_file = file.split("/")
  65. with hydra.initialize_config_dir(config_dir=normalize_path(f"{'/'.join(split_file[:-1])}/"), version_base=None):
  66. cfg = hydra.compose(config_name=split_file[-1])
  67. return cfg
  68. def get_model_arch_params(self, model_name: str) -> DictConfig:
  69. return self._get_model_cfg(model_name, AutoNACFileName.STRUCTURE_YAML)
  70. def get_model_recipe(self, model_name: str) -> DictConfig:
  71. return self._get_model_cfg(model_name, AutoNACFileName.RECIPE_YAML)
  72. def get_model_weights(self, model_name: str) -> str:
  73. if not client_enabled:
  74. return None
  75. return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH)
  76. def download_and_load_model_additional_code(self, model_name: str, target_path: str, package_name: str = "deci_model_code") -> None:
  77. """
  78. try to download code files for this model.
  79. if found, code files will be placed in the target_path/package_name and imported dynamically
  80. """
  81. file = self._get_file(model_name=model_name, file_name=AutoNACFileName.CODE_ZIP)
  82. package_path = os.path.join(target_path, package_name)
  83. if file is not None:
  84. # crete the directory
  85. os.makedirs(package_path, exist_ok=True)
  86. # extract code files
  87. with ZipFile(file) as zipfile:
  88. zipfile.extractall(package_path)
  89. # add an init file that imports all code files
  90. with open(os.path.join(package_path, "__init__.py"), "w") as init_file:
  91. all_str = "\n\n__all__ = ["
  92. for code_file in os.listdir(path=package_path):
  93. if code_file.endswith(".py") and not code_file.startswith("__init__"):
  94. init_file.write(f'import {code_file.replace(".py", "")}\n')
  95. all_str += f'"{code_file.replace(".py", "")}", '
  96. all_str += "]\n\n"
  97. init_file.write(all_str)
  98. # include in path and import
  99. sys.path.insert(1, package_path)
  100. importlib.import_module(package_name)
  101. logger.info(
  102. f"*** IMPORTANT ***: files required for the model {model_name} were downloaded and added to your code in:\n{package_path}\n"
  103. f"These files will be downloaded to the same location each time the model is fetched from the deci-client.\n"
  104. f"you can override this by passing models.get(... download_required_code=False) and importing the files yourself"
  105. )
  106. def upload_model(self, model: nn.Module, model_meta_data, optimization_request_form):
  107. """
  108. This function will upload the trained model to the Deci Lab
  109. Args:
  110. model: The resulting model from the training process
  111. model_meta_data: Metadata to accompany the model
  112. optimization_request_form: The optimization parameters
  113. """
  114. self.lab_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
  115. self.lab_client.add_model(
  116. add_model_request=model_meta_data,
  117. optimization_request=optimization_request_form,
  118. local_loaded_model=model,
  119. )
Discard
Tip!

Press p or to see the previous file or, n or to see the next file