Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#396 Trainer constructor cleanup

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
  1. import json
  2. import hydra
  3. import pkg_resources
  4. from hydra.core.global_hydra import GlobalHydra
  5. from omegaconf import DictConfig
  6. from super_gradients.common.abstractions.abstract_logger import get_logger
  7. logger = get_logger(__name__)
  8. client_enabled = True
  9. try:
  10. from deci_lab_client.client import DeciPlatformClient
  11. from deci_common.data_interfaces.files_data_interface import FilesDataInterface
  12. from deci_lab_client.models import AutoNACFileName
  13. from deci_lab_client import ApiException
  14. except (ImportError, NameError):
  15. client_enabled = False
  16. class DeciClient:
  17. """
  18. A client to deci platform and model zoo.
  19. requires credentials for connection
  20. """
  21. def __init__(self):
  22. if not client_enabled:
  23. logger.error('deci-lab-client or deci-common are not installed. Model cannot be loaded from deci lab.'
  24. 'Please install deci-lab-client>=2.55.0 and deci-common>=3.4.1')
  25. return
  26. self.lab_client = DeciPlatformClient()
  27. GlobalHydra.instance().clear()
  28. self.super_gradients_version = None
  29. try:
  30. self.super_gradients_version = pkg_resources.get_distribution("super_gradients").version
  31. except pkg_resources.DistributionNotFound:
  32. self.super_gradients_version = "3.0.0"
  33. def _get_file(self, model_name: str, file_name: str) -> str:
  34. try:
  35. response = self.lab_client.get_autonac_model_file_link(
  36. model_name=model_name, file_name=file_name, super_gradients_version=self.super_gradients_version
  37. )
  38. download_link = response.data
  39. except ApiException as e:
  40. if e.status == 401:
  41. logger.error("Unauthorized. wrong token or token was not defined. please login to deci-lab-client "
  42. "by calling DeciPlatformClient().login(<token>)")
  43. elif e.status == 400 and e.body is not None and "message" in e.body:
  44. logger.error(f"Deci client: {json.loads(e.body)['message']}")
  45. else:
  46. logger.error(e.body)
  47. return None
  48. return FilesDataInterface.download_temporary_file(file_url=download_link)
  49. def _get_model_cfg(self, model_name: str, cfg_file_name: str) -> DictConfig:
  50. if not client_enabled:
  51. return None
  52. file = self._get_file(model_name=model_name, file_name=cfg_file_name)
  53. if file is None:
  54. return None
  55. split_file = file.split("/")
  56. with hydra.initialize_config_dir(config_dir=f"{'/'.join(split_file[:-1])}/", version_base=None):
  57. cfg = hydra.compose(config_name=split_file[-1])
  58. return cfg
  59. def get_model_arch_params(self, model_name: str) -> DictConfig:
  60. return self._get_model_cfg(model_name, AutoNACFileName.STRUCTURE_YAML)
  61. def get_model_recipe(self, model_name: str) -> DictConfig:
  62. return self._get_model_cfg(model_name, AutoNACFileName.RECIPE_YAML)
  63. def get_model_weights(self, model_name: str) -> str:
  64. if not client_enabled:
  65. return None
  66. return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file