#546 Features/sg 409 check all params used

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:features/SG-409-check-all-params-used
@@ -1,14 +1,14 @@
 from torch import nn
 from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.utils import get_param


 def create_conv_module(in_channels, out_channels, kernel_size=3, stride=1):
     padding = (kernel_size - 1) // 2
     nn_sequential_module = nn.Sequential()
-    nn_sequential_module.add_module('Conv2d', nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
-                                                        stride=stride, padding=padding, bias=False))
-    nn_sequential_module.add_module('BatchNorm2d', nn.BatchNorm2d(out_channels))
-    nn_sequential_module.add_module('LeakyRelu', nn.LeakyReLU())
+    nn_sequential_module.add_module("Conv2d", nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False))
+    nn_sequential_module.add_module("BatchNorm2d", nn.BatchNorm2d(out_channels))
+    nn_sequential_module.add_module("LeakyRelu", nn.LeakyReLU())

     return nn_sequential_module

@@ -72,19 +72,19 @@ class Darknet53(Darknet53Base):
         super(Darknet53, self).__init__()

         # IN ORDER TO ALLOW PASSING PARAMETERS WITH ARCH_PARAMS BUT NOT BREAK YOLOV3 INTEGRATION
-        self.backbone_mode = arch_params.backbone_mode if hasattr(arch_params, 'backbone_mode') else backbone_mode
-        self.num_classes = arch_params.num_classes if hasattr(arch_params, 'num_classes') else num_classes
+        self.backbone_mode = get_param(arch_params, "backbone_mode", backbone_mode)
+        self.num_classes = get_param(arch_params, "num_classes", num_classes)

         if not self.backbone_mode:
             # IF NOT USED AS A BACKEND BUT AS A CLASSIFIER WE ADD THE CLASSIFICATION LAYERS
             if self.num_classes is not None:
                 nn_sequential_block = nn.Sequential()
-                nn_sequential_block.add_module('global_avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
-                nn_sequential_block.add_module('view', ViewModule(1024))
-                nn_sequential_block.add_module('fc', nn.Linear(1024, self.num_classes))
+                nn_sequential_block.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1, 1)))
+                nn_sequential_block.add_module("view", ViewModule(1024))
+                nn_sequential_block.add_module("fc", nn.Linear(1024, self.num_classes))
                 self.modules_list.append(nn_sequential_block)
             else:
-                raise ValueError('num_classes must be specified to use Darknet53 as a classifier')
+                raise ValueError("num_classes must be specified to use Darknet53 as a classifier")

     def get_modules_list(self):
         return self.modules_list
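The change above replaces the hasattr-based fallback with get_param, which resolves a key from a dict-like or attribute-style params object and returns the supplied default when the key is absent (its implementation appears in the utils.py hunks further down). A minimal illustrative sketch of that behavior:

from super_gradients.training.utils import HpmStruct, get_param

arch_params = HpmStruct(num_classes=10)

# Present on the params object -> returns the stored value
assert get_param(arch_params, "num_classes", default_val=1000) == 10
# Missing -> falls back to the supplied default
assert get_param(arch_params, "backbone_mode", default_val=False) is False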
@@ -1,6 +1,7 @@
 from super_gradients.training.utils.utils import Timer, HpmStruct, WrappedModel, convert_to_tensor, get_param, tensor_container_to_device, random_seed
 from super_gradients.training.utils.checkpoint_utils import adapt_state_dict_to_fit_model_layer_names, raise_informative_runtime_error
 from super_gradients.training.utils.version_utils import torch_version_is_greater_or_equal
+from super_gradients.training.utils.config_utils import raise_if_unused_params, warn_if_unused_params

 __all__ = [
     "Timer",
@@ -13,4 +14,6 @@ __all__ = [
     "raise_informative_runtime_error",
     "random_seed",
     "torch_version_is_greater_or_equal",
+    "raise_if_unused_params",
+    "warn_if_unused_params",
 ]
super_gradients/training/utils/config_utils.py (new file):

import abc
from collections import defaultdict
from typing import Mapping, Iterable, Set, Union

from omegaconf import ListConfig, DictConfig

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.utils import HpmStruct

__all__ = ["raise_if_unused_params", "warn_if_unused_params", "UnusedConfigParamException"]

logger = get_logger(__name__)


class UnusedConfigParamException(Exception):
    pass


class AccessCounterMixin:
    """
    Implements an access-counting mechanism for configuration settings (dicts/lists).
    It is achieved by wrapping the underlying config and overriding the __getitem__ and __getattr__ methods
    to catch read operations and increment an access counter for each property.
    """

    _access_counter: Mapping[str, int]
    _prefix: str  # Prefix string

    def maybe_wrap_as_counter(self, value, key, count_usage: bool = True):
        """
        Return an attribute value, optionally wrapped as an access-counter adapter to trace read counts.

        :param value: Attribute value
        :param key: Attribute name
        :param count_usage: Whether to increment the usage count for the given attribute. Default is True.
        """
        key_with_prefix = self._prefix + str(key)
        if count_usage:
            self._access_counter[key_with_prefix] += 1
        if isinstance(value, Mapping):
            return AccessCounterDict(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
        if isinstance(value, Iterable) and not isinstance(value, str):
            return AccessCounterList(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
        return value

    @property
    def access_counter(self):
        return self._access_counter

    @abc.abstractmethod
    def get_all_params(self) -> Set[str]:
        raise NotImplementedError()

    def get_used_params(self) -> Set[str]:
        used_params = {k for (k, v) in self._access_counter.items() if v > 0}
        return used_params

    def get_unused_params(self) -> Set[str]:
        unused_params = self.get_all_params() - self.get_used_params()
        return unused_params


class AccessCounterDict(Mapping, AccessCounterMixin):
    def __init__(self, config: Mapping, access_counter: Mapping[str, int] = None, prefix: str = ""):
        super().__init__()
        self.config = config
        self._access_counter = access_counter or defaultdict(int)
        self._prefix = str(prefix)

    def __iter__(self):
        return self.config.__iter__()

    def __len__(self):
        return self.config.__len__()

    def __getitem__(self, item):
        return self.get(item)

    def __getattr__(self, item):
        value = self.config.__getitem__(item)
        return self.maybe_wrap_as_counter(value, item)

    def get(self, item, default=None):
        value = self.config.get(item, default)
        return self.maybe_wrap_as_counter(value, item)

    def get_all_params(self) -> Set[str]:
        keys = []
        for key, value in self.config.items():
            keys.append(self._prefix + str(key))
            value = self.maybe_wrap_as_counter(value, key, count_usage=False)
            if isinstance(value, AccessCounterMixin):
                keys += value.get_all_params()
        return set(keys)


class AccessCounterHpmStruct(Mapping, AccessCounterMixin):
    def __init__(self, config: HpmStruct, access_counter: Mapping[str, int] = None, prefix: str = ""):
        super().__init__()
        self.config = config
        self._access_counter = access_counter or defaultdict(int)
        self._prefix = str(prefix)

    def __iter__(self):
        return self.config.__dict__.__iter__()

    def __len__(self):
        return self.config.__dict__.__len__()

    def __repr__(self):
        return self.config.__repr__()

    def __str__(self):
        return self.config.__str__()

    def __getitem__(self, item):
        value = self.config.__dict__[item]
        return self.maybe_wrap_as_counter(value, item)

    def __getattr__(self, item):
        value = self.config.__dict__[item]
        return self.maybe_wrap_as_counter(value, item)

    def get(self, item, default=None):
        value = self.config.__dict__.get(item, default)
        return self.maybe_wrap_as_counter(value, item)

    def get_all_params(self) -> Set[str]:
        keys = []
        for key, value in self.config.__dict__.items():
            # Exclude the schema field from params
            if key == "schema":
                continue
            keys.append(self._prefix + str(key))
            value = self.maybe_wrap_as_counter(value, key, count_usage=False)
            if isinstance(value, AccessCounterMixin):
                keys += value.get_all_params()
        return set(keys)


class AccessCounterList(list, AccessCounterMixin):
    def __init__(self, config: Iterable, access_counter: Mapping[str, int] = None, prefix: str = ""):
        super().__init__(config)
        self._access_counter = access_counter or defaultdict(int)
        self._prefix = str(prefix)

    def __iter__(self):
        for index, value in enumerate(super().__iter__()):
            yield self.maybe_wrap_as_counter(value, index)

    def __getitem__(self, item):
        value = super().__getitem__(item)
        return self.maybe_wrap_as_counter(value, item)

    def get_all_params(self) -> Set[str]:
        keys = []
        for index, value in enumerate(super().__iter__()):
            keys.append(self._prefix + str(index))
            value = self.maybe_wrap_as_counter(value, index, count_usage=False)
            if isinstance(value, AccessCounterMixin):
                keys += value.get_all_params()
        return set(keys)


class ConfigInspector:
    def __init__(self, wrapped_config, unused_params_action: str):
        self.wrapped_config = wrapped_config
        self.unused_params_action = unused_params_action

    def __enter__(self):
        return self.wrapped_config

    def __exit__(self, exc_type, exc_val, exc_tb):
        unused_params = self.wrapped_config.get_unused_params()
        if len(unused_params):
            message = f"Detected unused parameters in configuration object that were not consumed by caller: {unused_params}"
            if self.unused_params_action == "raise":
                raise UnusedConfigParamException(message)
            elif self.unused_params_action == "warn":
                logger.warning(message)
            elif self.unused_params_action == "ignore":
                pass
            else:
                raise KeyError(f"Encountered unknown action key {self.unused_params_action}")


def raise_if_unused_params(config: Union[HpmStruct, DictConfig, ListConfig, Mapping, list, tuple]) -> ConfigInspector:
    """
    A helper function to check whether all configuration parameters were used in a given block of code. The motivation
    for this check is to ensure there are no typos or outdated configuration parameters.
    If at least one config parameter was not used, this function raises an UnusedConfigParamException exception.

    Example usage:

    >>> from super_gradients.training.utils import raise_if_unused_params
    >>>
    >>> with raise_if_unused_params(some_config) as some_config:
    >>>    do_something_with_config(some_config)
    >>>

    :param config: A config to check
    :return: An instance of ConfigInspector
    """
    if isinstance(config, HpmStruct):
        wrapper_cls = AccessCounterHpmStruct
    elif isinstance(config, (Mapping, DictConfig)):
        wrapper_cls = AccessCounterDict
    elif isinstance(config, (list, tuple, ListConfig)):
        wrapper_cls = AccessCounterList
    else:
        raise RuntimeError(f"Unsupported type. Root configuration object must be a mapping or list. Got type {type(config)}")

    return ConfigInspector(wrapper_cls(config), unused_params_action="raise")


def warn_if_unused_params(config):
    """
    A helper function to check whether all configuration parameters were used in a given block of code. The motivation
    for this check is to ensure there are no typos or outdated configuration parameters.
    If at least one config parameter was not used, this function emits a warning.

    Example usage:

    >>> from super_gradients.training.utils import warn_if_unused_params
    >>>
    >>> with warn_if_unused_params(some_config) as some_config:
    >>>    do_something_with_config(some_config)
    >>>

    :param config: A config to check
    :return: An instance of ConfigInspector
    """
    if isinstance(config, HpmStruct):
        wrapper_cls = AccessCounterHpmStruct
    elif isinstance(config, (Mapping, DictConfig)):
        wrapper_cls = AccessCounterDict
    elif isinstance(config, (list, tuple, ListConfig)):
        wrapper_cls = AccessCounterList
    else:
        raise RuntimeError("Unsupported type. Root configuration object must be a mapping or list.")

    return ConfigInspector(wrapper_cls(config), unused_params_action="warn")
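A minimal sketch of how the new context managers are meant to be used, mirroring the unit tests added later in this PR (the config keys are illustrative):

from super_gradients.training.utils import raise_if_unused_params, warn_if_unused_params
from super_gradients.training.utils.config_utils import UnusedConfigParamException


def build_model(cfg):
    # Only reads "a" and "b"; "unused_param" is never touched
    return cfg["a"] + cfg["b"]


config = {"a": 1, "b": 2, "unused_param": True}

try:
    with raise_if_unused_params(config) as tracked_config:
        build_model(tracked_config)
except UnusedConfigParamException as e:
    print(e)  # Detected unused parameters ... {'unused_param'}

# Same check, but only logs a warning instead of raising
with warn_if_unused_params(config) as tracked_config:
    build_model(tracked_config)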
@@ -66,7 +66,7 @@ class HpmStruct:
             `jsonschema.exceptions.SchemaError` if the schema itselfis invalid
         """
         if self.schema is None:
-            raise AttributeError('schema was not set')
+            raise AttributeError("schema was not set")
         else:
             validate(self.__dict__, self.schema)

@@ -89,7 +89,7 @@ class Timer:
         :param device: str
             'cpu'\'cuda'
         """
-        self.on_gpu = (device == 'cuda')
+        self.on_gpu = device == "cuda"
         # On GPU time is measured using cuda.events
         if self.on_gpu:
             self.starter = torch.cuda.Event(enable_timing=True)
@@ -141,8 +141,7 @@ class AverageMeter:
     def average(self):
         if self._sum is None:
             return 0
-        return ((self._sum / self._count).__float__()) if self._sum.dim() < 1 else tuple(
-            (self._sum / self._count).cpu().numpy())
+        return ((self._sum / self._count).__float__()) if self._sum.dim() < 1 else tuple((self._sum / self._count).cpu().numpy())

         # return (self._sum / self._count).__float__() if self._sum.dim() < 1 or len(self._sum) == 1 \
         #     else tuple((self._sum / self._count).cpu().numpy())
@@ -186,16 +185,16 @@ def get_param(params, name, default_val=None):
     :param default_val: assumed to be the same type as the value searched in the params
     :return:            the found value, or default if not found
     """
-    if isinstance(params, dict):
+    if isinstance(params, Mapping):
         if name in params:
-            if isinstance(default_val, dict):
+            if isinstance(default_val, Mapping):
                 return {**default_val, **params[name]}
             else:
                 return params[name]
         else:
             return default_val
     elif hasattr(params, name):
-        if isinstance(default_val, dict):
+        if isinstance(default_val, Mapping):
             return {**default_val, **getattr(params, name)}
         else:
             return getattr(params, name)
@@ -239,7 +238,7 @@ def random_seed(is_ddp, device, seed):
     :param device: 'cuda','cpu', 'cuda:<device_number>'
     :param seed: int, random seed to be set
     """
-    rank = 0 if not is_ddp else int(device.split(':')[1])
+    rank = 0 if not is_ddp else int(device.split(":")[1])
     torch.manual_seed(seed + rank)
     np.random.seed(seed + rank)
     random.seed(seed + rank)
@@ -266,22 +265,21 @@ def get_filename_suffix_by_framework(framework: str):
     @param framework: (str)
     @return: (str) the suffix for the specific framework
     """
-    frameworks_dict = \
-        {
-            'TENSORFLOW1': '.pb',
-            'TENSORFLOW2': '.zip',
-            'PYTORCH': '.pth',
-            'ONNX': '.onnx',
-            'TENSORRT': '.pkl',
-            'OPENVINO': '.pkl',
-            'TORCHSCRIPT': '.pth',
-            'TVM': '',
-            'KERAS': '.h5',
-            'TFLITE': '.tflite'
-        }
+    frameworks_dict = {
+        "TENSORFLOW1": ".pb",
+        "TENSORFLOW2": ".zip",
+        "PYTORCH": ".pth",
+        "ONNX": ".onnx",
+        "TENSORRT": ".pkl",
+        "OPENVINO": ".pkl",
+        "TORCHSCRIPT": ".pth",
+        "TVM": "",
+        "KERAS": ".h5",
+        "TFLITE": ".tflite",
+    }

     if framework.upper() not in frameworks_dict.keys():
-        raise ValueError(f'Unsupported framework: {framework}')
+        raise ValueError(f"Unsupported framework: {framework}")

     return frameworks_dict[framework.upper()]

@@ -294,15 +292,15 @@ def check_models_have_same_weights(model_1: torch.nn.Module, model_2: torch.nn.M
     @param model_2: Net to be checked
     @return: True iff the two networks have the same weights
     """
-    model_1, model_2 = model_1.to('cpu'), model_2.to('cpu')
+    model_1, model_2 = model_1.to("cpu"), model_2.to("cpu")
     models_differ = 0
     for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
         if torch.equal(key_item_1[1], key_item_2[1]):
             pass
         else:
             models_differ += 1
-            if (key_item_1[0] == key_item_2[0]):
-                print(f'Layer names match but layers have different weights for layers: {key_item_1[0]}')
+            if key_item_1[0] == key_item_2[0]:
+                print(f"Layer names match but layers have different weights for layers: {key_item_1[0]}")
     if models_differ == 0:
         return True
     else:
@@ -320,7 +318,7 @@ def recursive_override(base: dict, extension: dict):
             base[k] = extension[k]


-def download_and_unzip_from_url(url, dir='.', unzip=True, delete=True):
+def download_and_unzip_from_url(url, dir=".", unzip=True, delete=True):
     """
     Downloads a zip file from url to dir, and unzips it.

@@ -341,14 +339,14 @@ def download_and_unzip_from_url(url, dir='.', unzip=True, delete=True):
         if Path(url).is_file():  # exists in current path
             Path(url).rename(f)  # move to dir
         elif not f.exists():
-            print(f'Downloading {url} to {f}...')
+            print(f"Downloading {url} to {f}...")
             torch.hub.download_url_to_file(url, f, progress=True)  # torch download
-        if unzip and f.suffix in ('.zip', '.gz'):
-            print(f'Unzipping {f}...')
-            if f.suffix == '.zip':
+        if unzip and f.suffix in (".zip", ".gz"):
+            print(f"Unzipping {f}...")
+            if f.suffix == ".zip":
                 ZipFile(f).extractall(path=dir)  # unzip
-            elif f.suffix == '.gz':
-                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
+            elif f.suffix == ".gz":
+                os.system(f"tar xfz {f} --directory {f.parent}")  # unzip
             if delete:
                 f.unlink()  # remove zip

@@ -358,7 +356,7 @@ def download_and_unzip_from_url(url, dir='.', unzip=True, delete=True):
         download_one(u, dir)


-def download_and_untar_from_url(urls: List[str], dir: Union[str, Path] = '.'):
+def download_and_untar_from_url(urls: List[str], dir: Union[str, Path] = "."):
     """
     Download a file from url and untar.

@@ -375,13 +373,13 @@ def download_and_untar_from_url(urls: List[str], dir: Union[str, Path] = '.'):
         if url_path.is_file():
             url_path.rename(filepath)
         elif not filepath.exists():
-            logger.info(f'Downloading {url} to {filepath}...')
+            logger.info(f"Downloading {url} to {filepath}...")
             torch.hub.download_url_to_file(url, str(filepath), progress=True)

         modes = {".tar.gz": "r:gz", ".tar": "r:"}
         assert filepath.suffix in modes.keys(), f"{filepath} has {filepath.suffix} suffix which is not supported"

-        logger.info(f'Extracting to {dir}...')
+        logger.info(f"Extracting to {dir}...")
         with tarfile.open(filepath, mode=modes[filepath.suffix]) as f:
             f.extractall(dir)
         filepath.unlink()
@@ -418,7 +416,7 @@ def get_orientation_key() -> int:
     """Get the orientation key according to PIL, which is useful to get the image size for instance
     :return: Orientation key according to PIL"""
     for key, value in ExifTags.TAGS.items():
-        if value == 'Orientation':
+        if value == "Orientation":
             return key


@@ -442,14 +440,14 @@ def exif_size(image: Image) -> Tuple[int, int]:
             elif rotation == 8:
                 image_size = (image_size[1], image_size[0])
     except Exception as ex:
-        print('Caught Exception trying to rotate: ' + str(image) + str(ex))
+        print("Caught Exception trying to rotate: " + str(image) + str(ex))
     width, height = image_size
     return height, width


 def get_image_size_from_path(img_path: str) -> Tuple[int, int]:
     """Get the image size of an image at a specific path"""
-    with open(img_path, 'rb') as f:
+    with open(img_path, "rb") as f:
         return exif_size(Image.open(f))

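Aside from the quote-style normalization, the behavioral change in the utils.py hunks above is that get_param now treats any Mapping (not only a plain dict) as a dict-like container, both for the params object and for the default value. A hedged sketch of the merge branch, where a nested default is combined with the value found in params and the params values win:

from super_gradients.training.utils import get_param

train_params = {"optimizer_params": {"momentum": 0.95}}
defaults = {"momentum": 0.9, "weight_decay": 1e-4}

# Mapping branch: {**default_val, **params[name]} -> params override the defaults
merged = get_param(train_params, "optimizer_params", defaults)
assert merged == {"momentum": 0.95, "weight_decay": 1e-4}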
@@ -47,6 +47,7 @@ from tests.unit_tests.detection_caching import TestDetectionDatasetCaching
 from tests.unit_tests.multi_scaling_test import MultiScaleTest
 from tests.unit_tests.ppyoloe_unit_test import PPYoloETests
 from tests.unit_tests.bbox_formats_test import BBoxFormatsTest
+from tests.unit_tests.config_inspector_test import ConfigInspectTest


 class CoreUnitTestSuiteRunner:
@@ -105,6 +106,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ResumeTrainingTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(CallTrainAfterTestTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionOutputAdapter))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ConfigInspectTest))

     def _add_modules_to_end_to_end_tests_suite(self):
         """
@@ -21,6 +21,7 @@ from tests.unit_tests.pretrained_models_unit_test import PretrainedModelsUnitTes
 from tests.unit_tests.conv_bn_relu_test import TestConvBnRelu
 from tests.unit_tests.initialize_with_dataloaders_test import InitializeWithDataloadersTest
 from tests.unit_tests.training_params_factory_test import TrainingParamsTest
+from tests.unit_tests.config_inspector_test import ConfigInspectTest


 __all__ = [
@@ -46,4 +47,5 @@ __all__ = [
     "CallTrainTwiceTest",
     "ResumeTrainingTest",
     "CallTrainAfterTestTest",
+    "ConfigInspectTest",
 ]
tests/unit_tests/config_inspector_test.py (new file):

import copy
import os
import unittest

import pkg_resources
from omegaconf import OmegaConf

from super_gradients.training.models import SgModule, get_arch_params
from super_gradients.training.models.model_factory import get_architecture
from super_gradients.training.utils import HpmStruct
from super_gradients.training.utils.config_utils import raise_if_unused_params, UnusedConfigParamException, AccessCounterDict, AccessCounterHpmStruct
from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names


class ConfigInspectTest(unittest.TestCase):
    def test_inspector_raise_on_unused_args(self):
        def model_factory(cfg):
            return cfg["a"] + cfg["b"]

        original_config = {"unused_param": True, "a": 1, "b": 2}

        with self.assertRaisesRegex(UnusedConfigParamException, "Detected unused parameters in configuration object that were not consumed by caller"):
            config = copy.deepcopy(original_config)
            with raise_if_unused_params(config) as config:
                _ = model_factory(config)

        with self.assertRaisesRegex(UnusedConfigParamException, "Detected unused parameters in configuration object that were not consumed by caller"):
            config = OmegaConf.create(copy.deepcopy(original_config))
            with raise_if_unused_params(config) as config:
                _ = model_factory(config)

        with self.assertRaisesRegex(UnusedConfigParamException, "Detected unused parameters in configuration object that were not consumed by caller"):
            config = HpmStruct(**copy.deepcopy(original_config))
            with raise_if_unused_params(copy.deepcopy(config)) as config:
                _ = model_factory(config)

    def test_inspector_raise_on_unused_args_with_modification_of_the_config(self):
        def model_factory(cfg):
            cfg["this_is_a_test_property_that_is_set_but_not_used"] = 42
            cfg["this_is_a_test_property_that_is_set_and_used"] = 39
            return cfg["a"] + cfg["b"] + cfg["this_is_a_test_property_that_is_set_and_used"]

        original_config = {"unused_param": True, "a": 1, "b": 2}

        with self.assertRaisesRegex(UnusedConfigParamException, "Detected unused parameters in configuration object that were not consumed by caller"):
            config = copy.deepcopy(original_config)
            with raise_if_unused_params(config) as config:
                result = model_factory(config)
                self.assertEqual(result, 42)
                self.assertTrue("this_is_a_test_property_that_is_set_and_used" in config.get_used_params())

        with self.assertRaisesRegex(UnusedConfigParamException, "Detected unused parameters in configuration object that were not consumed by caller"):
            config = OmegaConf.create(copy.deepcopy(original_config))
            with raise_if_unused_params(config) as config:
                result = model_factory(config)
                self.assertEqual(result, 42)
                self.assertTrue("this_is_a_test_property_that_is_set_and_used" in config.get_used_params())

        with self.assertRaisesRegex(UnusedConfigParamException, "Detected unused parameters in configuration object that were not consumed by caller"):
            config = HpmStruct(**copy.deepcopy(original_config))
            with raise_if_unused_params(copy.deepcopy(config)) as config:
                result = model_factory(config)
                self.assertEqual(result, 42)
                self.assertTrue("this_is_a_test_property_that_is_set_and_used" in config.get_used_params())

    def test_inspector_with_dict_and_list(self):
        config = {
            "beta": 0.73,
            "lr": 1e-4,
            "encoder": {
                "indexes": [1, 2, 3],
                "pretrained": True,
                "backbone": "yolov3",
                "layers": [
                    {"blocks": 4},
                    {"blocks": 3},
                    {"blocks": 6},
                    {"blocks": 9},
                ],
            },
        }
        c = AccessCounterDict(config)

        # Simulate parameters usage
        print(c["beta"])
        print(c["encoder"]["layers"])
        print(sum(c["encoder"]["indexes"]))
        print(c["beta"])
        print(c["encoder"]["layers"][0])
        print(c["encoder"]["layers"][3]["blocks"])

        print("All parameters")
        print(c.get_all_params())

        print("Unused parameters")
        print(c.get_unused_params())

        self.assertSetEqual(
            c.get_unused_params(),
            {
                "lr",
                "encoder.pretrained",
                "encoder.backbone",
                "encoder.layers.0.blocks",
                "encoder.layers.1",
                "encoder.layers.1.blocks",
                "encoder.layers.2",
                "encoder.layers.2.blocks",
            },
        )

    def test_inspector_with_omegaconf(self):
        config = {
            "beta": 0.73,
            "lr": 1e-4,
            "encoder": {
                "indexes": [1, 2, 3],
                "pretrained": True,
                "backbone": "yolov3",
                "layers": [
                    {"blocks": 4},
                    {"blocks": 3},
                    {"blocks": 6},
                    {"blocks": 9},
                ],
            },
        }
        config = OmegaConf.create(config)
        c = AccessCounterDict(config)

        # Simulate parameters usage
        print(c.beta)
        print(c.encoder.layers)
        print(sum(c.encoder.indexes))
        print(c.encoder.layers[0])
        print(c.encoder.layers[3].blocks)

        print("All parameters")
        print(c.get_all_params())

        print("Unused parameters")
        print(c.get_unused_params())

        self.assertSetEqual(
            c.get_unused_params(),
            {
                "lr",
                "encoder.pretrained",
                "encoder.backbone",
                "encoder.layers.0.blocks",
                "encoder.layers.1",
                "encoder.layers.1.blocks",
                "encoder.layers.2",
                "encoder.layers.2.blocks",
            },
        )

    def test_inspector_with_hpm_struct(self):
        config = {
            "beta": 0.73,
            "lr": 1e-4,
            "encoder": {
                "indexes": [1, 2, 3],
                "pretrained": True,
                "backbone": "yolov3",
                "layers": [
                    {"blocks": 4},
                    {"blocks": 3},
                    {"blocks": 6},
                    {"blocks": 9},
                ],
            },
        }
        config = HpmStruct(**config)
        c = AccessCounterHpmStruct(config)

        # Simulate parameters usage
        print(c.beta)
        print(c.encoder.layers)
        print(sum(c.encoder.indexes))
        print(c.encoder.layers[0])
        print(c.encoder.layers[3].blocks)

        print("All parameters")
        print(c.get_all_params())

        print("Unused parameters")
        print(c.get_unused_params())

        self.assertSetEqual(
            c.get_unused_params(),
            {
                "lr",
                "encoder.pretrained",
                "encoder.backbone",
                "encoder.layers.0.blocks",
                "encoder.layers.1",
                "encoder.layers.1.blocks",
                "encoder.layers.2",
                "encoder.layers.2.blocks",
            },
        )

    def get_all_arch_params_configs(self):
        config_path = pkg_resources.resource_filename("super_gradients.recipes", "arch_params")
        configs = [path.replace(".yaml", "") for path in sorted(os.listdir(config_path)) if path.endswith(".yaml")]
        return configs

    def test_resnet18_cifar_arch_params(self):
        arch_params = get_arch_params("resnet18_cifar_arch_params")
        architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture("resnet18", HpmStruct(**arch_params))

        with raise_if_unused_params(arch_params) as tracked_arch_params:
            _ = architecture_cls(arch_params=tracked_arch_params)

        with self.assertRaisesRegex(UnusedConfigParamException, "Detected unused parameters in configuration object that were not consumed by caller"):
            arch_params.override(me_is_not_used=True)
            with raise_if_unused_params(arch_params) as tracked_arch_params:
                _ = architecture_cls(arch_params=tracked_arch_params)

    @unittest.expectedFailure
    def test_model_from_arch_params(self):
        all_configs = self.get_all_arch_params_configs()
        for config_name in all_configs:
            with self.subTest(config_name):
                model_name = config_name.replace("_arch_params", "")
                arch_params = get_arch_params(config_name)
                architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture(model_name, HpmStruct(**arch_params))
                self.assertIsNotNone(arch_params, msg=model_name)

                if not issubclass(architecture_cls, SgModule):
                    # This instantiation method is not supported as unpacking arch_params would cause root params to be considered "used"
                    # net = architecture_cls(**arch_params.to_dict(include_schema=False))
                    self.skipTest("Skipping test since model class is not subclass of SgModule")
                else:
                    # Most of the SG models work with a single param named "arch_params" of type HpmStruct, but a few take **kwargs instead
                    if "arch_params" not in get_callable_param_names(architecture_cls):
                        self.skipTest("Skipping test since model c'tor does not receive arch_params argument")
                        # This instantiation method is not supported as unpacking arch_params would cause root params to be considered "used"
                        # net = architecture_cls(**arch_params.to_dict(include_schema=False))
                        pass
                    else:
                        try:
                            _ = architecture_cls(arch_params=arch_params)
                        except Exception as e:
                            self.skipTest(f"Skipping test since model cannot be instantiated at all {e}")

                        with raise_if_unused_params(arch_params) as tracked_arch_params:
                            _ = architecture_cls(arch_params=tracked_arch_params)