Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#549 Feature/infra 1481 call integration tests

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/infra-1481_call_integration_tests
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
  1. import re
  2. from typing import Union, Tuple, List, Type
  3. from types import TracebackType
  4. from super_gradients.common.crash_handler.utils import indent_string, fmt_txt, json_str_to_dict
  5. from super_gradients.common.abstractions.abstract_logger import get_logger
  6. logger = get_logger(__name__)
  7. class CrashTip:
  8. """Base class to add tips to exceptions raised while using SuperGradients.
  9. A tip is a more informative message with some suggestions for possible solutions or places to debug.
  10. """
  11. @classmethod
  12. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> bool:
  13. """
  14. Check if this tip is relevant.
  15. Beside the class, the input params are as returned by sys.exc_info():
  16. :param cls: Class inheriting from CrashTip
  17. :param exc_type: Type of exception
  18. :param exc_value: Exception
  19. :param exc_traceback: Traceback
  20. :return: True if the current class can help with the exception
  21. """
  22. raise NotImplementedError
  23. @classmethod
  24. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  25. """
  26. Provide a customized tip for the exception, combining explanation and solution.
  27. Beside the class, the input params are as returned by sys.exc_info():
  28. :param cls: Class inheriting from CrashTip
  29. :param exc_type: Type of exception
  30. :param exc_value: Exception
  31. :param exc_traceback: Traceback
  32. :return: Tip
  33. """
  34. raise NotImplementedError
  35. @classmethod
  36. def get_message(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> Union[None, str]:
  37. """
  38. Wrap the tip in a nice message.
  39. Beside the class, the input params are as returned by sys.exc_info():
  40. :param cls: Class inheriting from CrashTip
  41. :param exc_type: Type of exception
  42. :param exc_value: Exception
  43. :param exc_traceback: Traceback
  44. :return: Tip
  45. """
  46. try:
  47. def format_tip(tip_index: int, tip: str):
  48. first_sentence, *following_sentences = tip.split("\n")
  49. first_sentence = f"{tip_index+1}. {first_sentence}"
  50. following_sentences = [f" {sentence}" for sentence in following_sentences]
  51. return "\n".join([first_sentence] + following_sentences)
  52. tips: List[str] = cls._get_tips(exc_type, exc_value, exc_traceback)
  53. formatted_tips: str = "\n".join([format_tip(i, tip) for i, tip in enumerate(tips)])
  54. message = (
  55. "═══════════════════════════════════════════╦═════════════════════════╦════════════════════════════════════════════════════════════\n"
  56. " ║ SuperGradient Crash tip ║ \n"
  57. " ╚═════════════════════════╝ \n"
  58. f"{fmt_txt('Something went wrong!', color='red', bold=True)} You can find below potential solution(s) to this error: \n\n"
  59. f"{formatted_tips}\n"
  60. f"{len(tips)+1}. If the proposed solution(s) did not help, feel free to contact the SuperGradient team or to open a ticket on "
  61. f"https://github.com/Deci-AI/super-gradients/issues/new/choose\n\n"
  62. "see the trace above...\n"
  63. "══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════\n"
  64. )
  65. return "\n" + message
  66. except Exception:
  67. # It is important that the crash tip does not crash itself, because it is called atexit!
  68. # Otherwise, the user would get a crash on top of another crash and this would be extremly confusing
  69. return None
  70. class TorchCudaMissingTip(CrashTip):
  71. @classmethod
  72. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> bool:
  73. pattern = "symbol cublasLtHSHMatmulAlgoInit version"
  74. return isinstance(exc_value, OSError) and pattern in str(exc_value)
  75. @classmethod
  76. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  77. tip = (
  78. f"This error may indicate {fmt_txt('CUDA libraries version conflict', color='red')} (When Torchvision & Torch are installed for different "
  79. f"CUDA versions) or the {fmt_txt('absence of CUDA support in PyTorch', color='red')}.\n"
  80. "To fix this you can:\n"
  81. f" a. Make sure to {fmt_txt('uninstall torch, torchvision', color='green')}\n"
  82. f" b. {fmt_txt('Install the torch version', color='green')} that respects your os & compute platform "
  83. f"{fmt_txt('following the instruction from https://pytorch.org/', color='green')}"
  84. )
  85. return [tip]
  86. class RecipeFactoryFormatTip(CrashTip):
  87. @classmethod
  88. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> bool:
  89. pattern = "Malformed object definition in configuration. Expecting either a string of object type or a single entry dictionary"
  90. return isinstance(exc_value, RuntimeError) and pattern in str(exc_value)
  91. @classmethod
  92. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  93. factory_name, params_dict = RecipeFactoryFormatTip._get_factory_with_params(exc_value)
  94. formatted_factory_name = fmt_txt(factory_name, bold=True, color="green")
  95. params_in_yaml = "\n".join(f" {k}: {v}" for k, v in params_dict.items())
  96. user_yaml = f"- {factory_name}:\n" + params_in_yaml
  97. formatted_user_yaml = fmt_txt(user_yaml, indent=4, color="red")
  98. correct_yaml = f"- {factory_name}:\n" + indent_string(params_in_yaml, indent_size=2)
  99. formatted_correct_yaml = fmt_txt(correct_yaml, indent=4, color="green")
  100. tip = f"There is an indentation error in the recipe, while creating {formatted_factory_name}.\n"
  101. tip += "If your wrote this in your recipe:\n"
  102. tip += f"{formatted_user_yaml}\n"
  103. tip += "Please change it to:\n"
  104. tip += f"{formatted_correct_yaml}"
  105. tips = [tip]
  106. return tips
  107. @staticmethod
  108. def _get_factory_with_params(exc_value: Exception) -> Tuple[str, dict]:
  109. """Utility function to extract useful features from the exception.
  110. :return: Name of the factory that (we assume) was not correctly defined
  111. :return: Parameters that are passed to that factory
  112. """
  113. description = str(exc_value)
  114. params_dict = re.search(r"received: (.*?)$", description).group(1)
  115. params_dict = json_str_to_dict(params_dict)
  116. factory_name = next(iter(params_dict))
  117. params_dict.pop(factory_name)
  118. return factory_name, params_dict
  119. class DDPNotInitializedTip(CrashTip):
  120. """Note: I think that this should be caught within the code instead"""
  121. @classmethod
  122. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType):
  123. expected_str = "Default process group has not been initialized, please make sure to call init_process_group."
  124. return isinstance(exc_value, RuntimeError) and expected_str in str(exc_value)
  125. @classmethod
  126. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  127. tip = (
  128. "Your environment was not setup correctly for DDP.\n"
  129. "Please run at the beginning of your script:\n"
  130. f">>> {fmt_txt('from super_gradients.training.utils.distributed_training_utils import setup_device', color='green')}\n"
  131. f">>> {fmt_txt('from super_gradients.common.data_types.enum import MultiGPUMode', color='green')}\n"
  132. f">>> {fmt_txt('setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=...)', color='green')}"
  133. )
  134. return [tip]
  135. class WrongHydraVersionTip(CrashTip):
  136. """Note: I think that this should be caught within the code instead"""
  137. @classmethod
  138. def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType):
  139. expected_str = "__init__() got an unexpected keyword argument 'version_base'"
  140. return isinstance(exc_value, TypeError) and expected_str == str(exc_value)
  141. @classmethod
  142. def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]:
  143. import hydra
  144. tip = (
  145. f"{fmt_txt(f'hydra=={hydra.__version__}', color='red')} is not supported by SuperGradients. "
  146. f"Please run {fmt_txt('pip install hydra-core==1.2.0', color='green')}"
  147. )
  148. return [tip]
  149. # /!\ Only the CrashTips classes listed below will be used !! /!\
  150. ALL_CRASH_TIPS: List[Type[CrashTip]] = [TorchCudaMissingTip, RecipeFactoryFormatTip, DDPNotInitializedTip, WrongHydraVersionTip]
  151. def get_relevant_crash_tip_message(exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> Union[None, str]:
  152. """Get a CrashTip class if relevant for input exception"""
  153. for crash_tip in ALL_CRASH_TIPS:
  154. if crash_tip.is_relevant(exc_type, exc_value, exc_traceback):
  155. return crash_tip.get_message(exc_type, exc_value, exc_traceback)
  156. return None
Discard
Tip!

Press p to see the previous file, or n to see the next file