Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#661 documentation on using configuration files

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-608_configuration_files
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
  1. import re
  2. from typing import Union, Tuple, List, Type
  3. from types import TracebackType
  4. import omegaconf
  5. from super_gradients.common.crash_handler.utils import indent_string, fmt_txt, json_str_to_dict
  6. from super_gradients.common.abstractions.abstract_logger import get_logger
  7. logger = get_logger(__name__)
  8. class CrashTip:
  9. """Base class to add tips to exceptions raised while using SuperGradients.
  10. A tip is a more informative message with some suggestions for possible solutions or places to debug.
  11. """
  12. _subclasses: List[Type["CrashTip"]] = []
  13. @classmethod
  14. def get_sub_classes(cls) -> List[Type["CrashTip"]]:
  15. """Get all the classes inheriting from CrashTip"""
  16. return cls._subclasses
  17. def __init_subclass__(cls):
  18. """Register any class inheriting from CrashTip"""
  19. CrashTip._subclasses.append(cls)
  20. @classmethod
  21. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> bool:
  22. """
  23. Check if this tip is relevant.
  24. Beside the class, the input params are as returned by sys.exc_info():
  25. :param cls: Class inheriting from CrashTip
  26. :param exc_type: Type of exception
  27. :param exc_value: Exception
  28. :param exc_traceback: Traceback
  29. :return: True if the current class can help with the exception
  30. """
  31. raise NotImplementedError
  32. @classmethod
  33. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  34. """
  35. Provide a customized tip for the exception, combining explanation and solution.
  36. Beside the class, the input params are as returned by sys.exc_info():
  37. :param cls: Class inheriting from CrashTip
  38. :param exc_type: Type of exception
  39. :param exc_value: Exception
  40. :param exc_traceback: Traceback
  41. :return: Tip
  42. """
  43. raise NotImplementedError
  44. @classmethod
  45. def get_message(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> Union[None, str]:
  46. """
  47. Wrap the tip in a nice message.
  48. Beside the class, the input params are as returned by sys.exc_info():
  49. :param cls: Class inheriting from CrashTip
  50. :param exc_type: Type of exception
  51. :param exc_value: Exception
  52. :param exc_traceback: Traceback
  53. :return: Tip
  54. """
  55. try:
  56. def format_tip(tip_index: int, tip: str):
  57. first_sentence, *following_sentences = tip.split("\n")
  58. first_sentence = f"{tip_index+1}. {first_sentence}"
  59. following_sentences = [f" {sentence}" for sentence in following_sentences]
  60. return "\n".join([first_sentence] + following_sentences)
  61. tips: List[str] = cls._get_tips(exc_type, exc_value, exc_traceback)
  62. formatted_tips: str = "\n".join([format_tip(i, tip) for i, tip in enumerate(tips)])
  63. message = (
  64. "═══════════════════════════════════════════╦═════════════════════════╦════════════════════════════════════════════════════════════\n"
  65. " ║ SuperGradient Crash tip ║ \n"
  66. " ╚═════════════════════════╝ \n"
  67. f"{fmt_txt('Something went wrong!', color='red', bold=True)} You can find below potential solution(s) to this error: \n\n"
  68. f"{formatted_tips}\n"
  69. f"{len(tips)+1}. If the proposed solution(s) did not help, feel free to contact the SuperGradient team or to open a ticket on "
  70. f"https://github.com/Deci-AI/super-gradients/issues/new/choose\n\n"
  71. "see the trace above...\n"
  72. "══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════\n"
  73. )
  74. return "\n" + message
  75. except Exception:
  76. # It is important that the crash tip does not crash itself, because it is called atexit!
  77. # Otherwise, the user would get a crash on top of another crash and this would be extremly confusing
  78. return None
  79. class TorchCudaMissingTip(CrashTip):
  80. @classmethod
  81. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> bool:
  82. pattern = "symbol cublasLtHSHMatmulAlgoInit version"
  83. return isinstance(exc_value, OSError) and pattern in str(exc_value)
  84. @classmethod
  85. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  86. tip = (
  87. f"This error may indicate {fmt_txt('CUDA libraries version conflict', color='red')} (When Torchvision & Torch are installed for different "
  88. f"CUDA versions) or the {fmt_txt('absence of CUDA support in PyTorch', color='red')}.\n"
  89. "To fix this you can:\n"
  90. f" a. Make sure to {fmt_txt('uninstall torch, torchvision', color='green')}\n"
  91. f" b. {fmt_txt('Install the torch version', color='green')} that respects your os & compute platform "
  92. f"{fmt_txt('following the instruction from https://pytorch.org/', color='green')}"
  93. )
  94. return [tip]
  95. class RecipeFactoryFormatTip(CrashTip):
  96. @classmethod
  97. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> bool:
  98. pattern = "Malformed object definition in configuration. Expecting either a string of object type or a single entry dictionary"
  99. return isinstance(exc_value, RuntimeError) and pattern in str(exc_value)
  100. @classmethod
  101. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  102. factory_name, params_dict = RecipeFactoryFormatTip._get_factory_with_params(exc_value)
  103. formatted_factory_name = fmt_txt(factory_name, bold=True, color="green")
  104. params_in_yaml = "\n".join(f" {k}: {v}" for k, v in params_dict.items())
  105. user_yaml = f"- {factory_name}:\n" + params_in_yaml
  106. formatted_user_yaml = fmt_txt(user_yaml, indent=4, color="red")
  107. correct_yaml = f"- {factory_name}:\n" + indent_string(params_in_yaml, indent_size=2)
  108. formatted_correct_yaml = fmt_txt(correct_yaml, indent=4, color="green")
  109. tip = f"There is an indentation error in the recipe, while creating {formatted_factory_name}.\n"
  110. tip += "If your wrote this in your recipe:\n"
  111. tip += f"{formatted_user_yaml}\n"
  112. tip += "Please change it to:\n"
  113. tip += f"{formatted_correct_yaml}"
  114. tips = [tip]
  115. return tips
  116. @staticmethod
  117. def _get_factory_with_params(exc_value: Exception) -> Tuple[str, dict]:
  118. """Utility function to extract useful features from the exception.
  119. :return: Name of the factory that (we assume) was not correctly defined
  120. :return: Parameters that are passed to that factory
  121. """
  122. description = str(exc_value)
  123. params_dict = re.search(r"received: (.*?)$", description).group(1)
  124. params_dict = json_str_to_dict(params_dict)
  125. factory_name = next(iter(params_dict))
  126. params_dict.pop(factory_name)
  127. return factory_name, params_dict
  128. class DDPNotInitializedTip(CrashTip):
  129. """Note: I think that this should be caught within the code instead"""
  130. @classmethod
  131. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType):
  132. expected_str = "Default process group has not been initialized, please make sure to call init_process_group."
  133. return isinstance(exc_value, RuntimeError) and expected_str in str(exc_value)
  134. @classmethod
  135. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  136. tip = (
  137. "Your environment was not setup correctly for DDP.\n"
  138. "Please run at the beginning of your script:\n"
  139. f">>> {fmt_txt('from super_gradients.training.utils.distributed_training_utils import setup_device', color='green')}\n"
  140. f">>> {fmt_txt('from super_gradients.common.data_types.enum import MultiGPUMode', color='green')}\n"
  141. f">>> {fmt_txt('setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=...)', color='green')}"
  142. )
  143. return [tip]
  144. class WrongHydraVersionTip(CrashTip):
  145. """Note: I think that this should be caught within the code instead"""
  146. @classmethod
  147. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType):
  148. expected_str = "__init__() got an unexpected keyword argument 'version_base'"
  149. return isinstance(exc_value, TypeError) and expected_str == str(exc_value)
  150. @classmethod
  151. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  152. import hydra
  153. tip = (
  154. f"{fmt_txt(f'hydra=={hydra.__version__}', color='red')} is not supported by SuperGradients. "
  155. f"Please run {fmt_txt('pip install hydra-core==1.2.0', color='green')}"
  156. )
  157. return [tip]
  158. class InterpolationKeyErrorTip(CrashTip):
  159. @classmethod
  160. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType):
  161. expected_str = "Interpolation key "
  162. return isinstance(exc_value, omegaconf.errors.InterpolationKeyError) and expected_str in str(exc_value)
  163. @classmethod
  164. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  165. variable = re.search("'(.*?)'", str(exc_value)).group(1)
  166. tip = (
  167. f"It looks like you encountered an error related to interpolation of the variable '{variable}'.\n"
  168. "It's possible that this error is caused by not using the full path of the variable in your subfolder configuration.\n"
  169. f"Please make sure that you are referring to the variable using the "
  170. f"{fmt_txt('full path starting from the main configuration file', color='green')}.\n"
  171. f"Try to replace '{fmt_txt(f'${{{variable}}}', color='red')}' with '{fmt_txt(f'${{full.path.to.{variable}}}', color='green')}', \n"
  172. f" where 'full.path.to' is the actual path to reach '{variable}', starting from the root configuration file.\n"
  173. f"Example: '{fmt_txt('${dataset_params.train_dataloader_params.batch_size}', color='green')}' "
  174. f"instead of '{fmt_txt('${train_dataloader_params.batch_size}', color='red')}'.\n"
  175. )
  176. return [tip]
  177. def get_relevant_crash_tip_message(exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> Union[None, str]:
  178. """Get a CrashTip class if relevant for input exception"""
  179. for crash_tip in CrashTip.get_sub_classes():
  180. if crash_tip.is_relevant(exc_type, exc_value, exc_traceback):
  181. return crash_tip.get_message(exc_type, exc_value, exc_traceback)
  182. return None
Discard
Tip!

Press p or ← to see the previous file, or n or → to see the next file