Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

checkpoint_utils.py 13 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
  1. import os
  2. import tempfile
  3. import pkg_resources
  4. import torch
  5. from super_gradients.common import explicit_params_validation, ADNNModelRepositoryDataInterfaces
  6. from super_gradients.training.pretrained_models import MODEL_URLS
  7. try:
  8. from torch.hub import download_url_to_file, load_state_dict_from_url
  9. except (ModuleNotFoundError, ImportError, NameError):
  10. from torch.hub import _download_url_to_file as download_url_to_file
  11. def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, external_checkpoint_path: str):
  12. """
  13. Gets the local path to the checkpoint file, which will be:
  14. - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
  15. - if the checkpoint file is remotely located:
  16. when overwrite_local_checkpoint=True then it will be saved in a temporary path which will be returned,
  17. otherwise it will be downloaded to YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name and overwrite
  18. YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name/ckpt_name if such file exists.
  19. - external_checkpoint_path when external_checkpoint_path != None
  20. @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.
  21. @param experiment_name: experiment name attr in trainer
  22. @param ckpt_name: checkpoint filename
  23. @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)
  24. @return:
  25. """
  26. source_ckpt_folder_name = source_ckpt_folder_name or experiment_name
  27. ckpt_local_path = external_checkpoint_path or pkg_resources.resource_filename('checkpoints', source_ckpt_folder_name + os.path.sep + ckpt_name)
  28. return ckpt_local_path
  29. def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: str):
  30. """
  31. Adaptively loads state_dict to net, by adapting the state_dict to net's layer names first.
  32. @param net: (nn.Module) to load state_dict to
  33. @param state_dict: (dict) Chekpoint state_dict
  34. @param strict: (str) key matching strictness
  35. @return:
  36. """
  37. try:
  38. net.load_state_dict(state_dict['net'] if 'net' in state_dict.keys() else state_dict, strict=strict)
  39. except (RuntimeError, ValueError, KeyError) as ex:
  40. if strict == 'no_key_matching':
  41. adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict)
  42. net.load_state_dict(adapted_state_dict['net'], strict=True)
  43. else:
  44. raise_informative_runtime_error(net.state_dict(), state_dict, ex)
  45. @explicit_params_validation(validation_type='None')
  46. def copy_ckpt_to_local_folder(local_ckpt_destination_dir: str, ckpt_filename: str, remote_ckpt_source_dir: str = None,
  47. path_src: str = 'local', overwrite_local_ckpt: bool = False,
  48. load_weights_only: bool = False):
  49. """
  50. Copy the checkpoint from any supported source to a local destination path
  51. :param local_ckpt_destination_dir: destination where the checkpoint will be saved to
  52. :param ckpt_filename: ckpt_best.pth Or ckpt_latest.pth
  53. :param remote_ckpt_source_dir: Name of the source checkpoint to be loaded (S3 Model\full URL)
  54. :param path_src: S3 / url
  55. :param overwrite_local_ckpt: determines if checkpoint will be saved in destination dir or in a temp folder
  56. :return: Path to checkpoint
  57. """
  58. ckpt_file_full_local_path = None
  59. # IF NOT DEFINED - IT IS SET TO THE TARGET's FOLDER NAME
  60. remote_ckpt_source_dir = local_ckpt_destination_dir if remote_ckpt_source_dir is None else remote_ckpt_source_dir
  61. if not overwrite_local_ckpt:
  62. # CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO
  63. download_ckpt_destination_dir = tempfile.gettempdir()
  64. print('PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False '
  65. '-> IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART')
  66. else:
  67. # SAVE THE CHECKPOINT TO MODEL's FOLDER
  68. download_ckpt_destination_dir = pkg_resources.resource_filename('checkpoints', local_ckpt_destination_dir)
  69. if path_src.startswith('s3'):
  70. model_checkpoints_data_interface = ADNNModelRepositoryDataInterfaces(data_connection_location=path_src)
  71. # DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER
  72. ckpt_file_full_local_path = model_checkpoints_data_interface.load_remote_checkpoints_file(
  73. ckpt_source_remote_dir=remote_ckpt_source_dir,
  74. ckpt_destination_local_dir=download_ckpt_destination_dir,
  75. ckpt_file_name=ckpt_filename,
  76. overwrite_local_checkpoints_file=overwrite_local_ckpt)
  77. if not load_weights_only:
  78. # COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT
  79. model_checkpoints_data_interface.load_all_remote_log_files(model_name=remote_ckpt_source_dir,
  80. model_checkpoint_local_dir=download_ckpt_destination_dir)
  81. if path_src == 'url':
  82. ckpt_file_full_local_path = download_ckpt_destination_dir + os.path.sep + ckpt_filename
  83. # DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER
  84. download_url_to_file(remote_ckpt_source_dir, ckpt_file_full_local_path, progress=True)
  85. return ckpt_file_full_local_path
  86. def read_ckpt_state_dict(ckpt_path: str, device="cpu"):
  87. if not os.path.exists(ckpt_path):
  88. raise ValueError('Incorrect Checkpoint path')
  89. if device == "cuda":
  90. state_dict = torch.load(ckpt_path)
  91. else:
  92. state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
  93. return state_dict
  94. def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict,
  95. exclude: list = [], solver: callable = None):
  96. """
  97. Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit
  98. the ckpt in order to properly load the weights into the model. If unsuccessful - returns None
  99. :param model_state_dict: the model state_dict
  100. :param source_ckpt: checkpoint dict
  101. :param exclude optional list for excluded layers
  102. :param solver: callable with signature (ckpt_key, ckpt_val, model_key, model_val)
  103. that returns a desired weight for ckpt_val.
  104. :return: renamed checkpoint dict (if possible)
  105. """
  106. if 'net' in source_ckpt.keys():
  107. source_ckpt = source_ckpt['net']
  108. model_state_dict_excluded = {k: v for k, v in model_state_dict.items() if not any(x in k for x in exclude)}
  109. new_ckpt_dict = {}
  110. for (ckpt_key, ckpt_val), (model_key, model_val) in zip(source_ckpt.items(), model_state_dict_excluded.items()):
  111. if solver is not None:
  112. ckpt_val = solver(ckpt_key, ckpt_val, model_key, model_val)
  113. if ckpt_val.shape != model_val.shape:
  114. raise ValueError(f'ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}'
  115. f' with shape {model_val.shape} in the model')
  116. new_ckpt_dict[model_key] = ckpt_val
  117. return {'net': new_ckpt_dict}
  118. def raise_informative_runtime_error(state_dict, checkpoint, exception_msg):
  119. """
  120. Given a model state dict and source checkpoints, the method calls "adapt_state_dict_to_fit_model_layer_names"
  121. and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible
  122. """
  123. try:
  124. new_ckpt_dict = adapt_state_dict_to_fit_model_layer_names(state_dict, checkpoint)
  125. temp_file = tempfile.NamedTemporaryFile().name + '.pt'
  126. torch.save(new_ckpt_dict, temp_file)
  127. exception_msg = f"\n{'=' * 200}\n{str(exception_msg)} \nconvert ckpt via the utils.adapt_state_dict_to_fit_" \
  128. f"model_layer_names method\na converted checkpoint file was saved in the path {temp_file}\n{'=' * 200}"
  129. except ValueError as ex: # IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL
  130. exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do no fit, e.g.: {ex}\n{'=' * 200}"
  131. finally:
  132. raise RuntimeError(exception_msg)
  133. def load_checkpoint_to_model(ckpt_local_path: str, load_backbone: bool, net: torch.nn.Module, strict: str,
  134. load_weights_only: bool, load_ema_as_net: bool = False):
  135. """
  136. Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.
  137. @param load_ema_as_net: Will load the EMA inside the checkpoint file to the network when set
  138. @param ckpt_local_path: local path to the checkpoint file
  139. @param load_backbone: whether to load the checkpoint as a backbone
  140. @param net: network to load the checkpoint to
  141. @param strict:
  142. @param load_weights_only:
  143. @return:
  144. """
  145. if ckpt_local_path is None or not os.path.exists(ckpt_local_path):
  146. error_msg = 'Error - loading Model Checkpoint: Path {} does not exist'.format(ckpt_local_path)
  147. raise RuntimeError(error_msg)
  148. if load_backbone and not hasattr(net.module, 'backbone'):
  149. raise ValueError("No backbone attribute in net - Can't load backbone weights")
  150. # LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT
  151. checkpoint = read_ckpt_state_dict(ckpt_path=ckpt_local_path)
  152. if load_ema_as_net:
  153. if 'ema_net' not in checkpoint.keys():
  154. raise ValueError("Can't load ema network- no EMA network stored in checkpoint file")
  155. else:
  156. checkpoint['net'] = checkpoint['ema_net']
  157. # LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL
  158. if load_backbone:
  159. adaptive_load_state_dict(net.module.backbone, checkpoint, strict)
  160. else:
  161. adaptive_load_state_dict(net, checkpoint, strict)
  162. if load_weights_only or load_backbone:
  163. # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
  164. [checkpoint.pop(key) for key in list(checkpoint.keys()) if key != 'net']
  165. return checkpoint
  166. class MissingPretrainedWeightsException(Exception):
  167. """Exception raised by unsupported pretrianed model.
  168. Attributes:
  169. message -- explanation of the error
  170. """
  171. def __init__(self, desc):
  172. self.message = "Missing pretrained wights: " + desc
  173. super().__init__(self.message)
  174. def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):
  175. """
  176. Helper method for reshaping old pretrained checkpoint's focus weights to 6x6 conv weights.
  177. """
  178. if ckpt_val.shape != model_val.shape and ckpt_key == 'module._backbone._modules_list.0.conv.conv.weight' and \
  179. model_key == '_backbone._modules_list.0.conv.weight':
  180. model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
  181. model_val.data[:, :, 1::2, ::2] = ckpt_val.data[:, 3:6]
  182. model_val.data[:, :, ::2, 1::2] = ckpt_val.data[:, 6:9]
  183. model_val.data[:, :, 1::2, 1::2] = ckpt_val.data[:, 9:12]
  184. replacement = model_val
  185. else:
  186. replacement = ckpt_val
  187. return replacement
  188. def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
  189. """
  190. Loads pretrained weights from the MODEL_URLS dictionary to model
  191. @param architecture: name of the model's architecture
  192. @param model: model to load pretrinaed weights for
  193. @param pretrained_weights: name for the pretrianed weights (i.e imagenet)
  194. @return: None
  195. """
  196. model_url_key = architecture + '_' + str(pretrained_weights)
  197. if model_url_key not in MODEL_URLS.keys():
  198. raise MissingPretrainedWeightsException(model_url_key)
  199. url = MODEL_URLS[model_url_key]
  200. unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace('/', '_').replace(' ', '_')
  201. map_location = torch.device('cpu')
  202. pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
  203. _load_weights(architecture, model, pretrained_state_dict)
  204. def _load_weights(architecture, model, pretrained_state_dict):
  205. if 'ema_net' in pretrained_state_dict.keys():
  206. pretrained_state_dict['net'] = pretrained_state_dict['ema_net']
  207. solver = _yolox_ckpt_solver if "yolox" in architecture else None
  208. adapted_pretrained_state_dict = adapt_state_dict_to_fit_model_layer_names(model_state_dict=model.state_dict(),
  209. source_ckpt=pretrained_state_dict,
  210. solver=solver)
  211. model.load_state_dict(adapted_pretrained_state_dict['net'], strict=False)
  212. def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
  213. """
  214. Loads pretrained weights from the MODEL_URLS dictionary to model
  215. @param architecture: name of the model's architecture
  216. @param model: model to load pretrinaed weights for
  217. @param pretrained_weights: path tp pretrained weights
  218. @return: None
  219. """
  220. map_location = torch.device('cpu')
  221. pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
  222. _load_weights(architecture, model, pretrained_state_dict)
Tip!

Press p or the left-arrow key to see the previous file, or n or the right-arrow key to see the next file

Comments

Loading...