import os
import tempfile

import pkg_resources
import torch

from super_gradients.common import explicit_params_validation, ADNNModelRepositoryDataInterfaces
from super_gradients.training.pretrained_models import MODEL_URLS

try:
    from torch.hub import download_url_to_file, load_state_dict_from_url
except (ModuleNotFoundError, ImportError, NameError):
    # OLDER TORCH VERSIONS EXPOSE THESE UNDER DIFFERENT NAMES
    from torch.hub import _download_url_to_file as download_url_to_file
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, external_checkpoint_path: str):
    """
    Gets the local path to the checkpoint file, which will be:
        - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
        - If the checkpoint file is remotely located:
          when overwrite_local_checkpoint=True it will be saved to a temporary path, which will be returned;
          otherwise it will be downloaded to YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name and will
          overwrite YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name/ckpt_name if such a file exists.
        - external_checkpoint_path, when external_checkpoint_path != None.

    @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None, experiment_name is used.
    @param experiment_name: The experiment_name attribute in the trainer.
    @param ckpt_name: Checkpoint filename.
    @param external_checkpoint_path: Full path to the checkpoint file (which may be located outside the
        super_gradients/checkpoints directory).
    @return: Local path to the checkpoint file.
    """
    source_ckpt_folder_name = source_ckpt_folder_name or experiment_name
    ckpt_local_path = external_checkpoint_path or pkg_resources.resource_filename('checkpoints', source_ckpt_folder_name + os.path.sep + ckpt_name)
    return ckpt_local_path
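
# USAGE SKETCH (illustrative only - the experiment name below is hypothetical):
#
#     ckpt_path = get_ckpt_local_path(source_ckpt_folder_name=None,
#                                     experiment_name='my_resnet18_experiment',
#                                     ckpt_name='ckpt_best.pth',
#                                     external_checkpoint_path=None)
#     # RESOLVES TO <checkpoints package root>/my_resnet18_experiment/ckpt_best.pth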


def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: str):
    """
    Adaptively loads state_dict into net, by first adapting the state_dict to net's layer names if needed.

    @param net: (nn.Module) Network to load the state_dict into.
    @param state_dict: (dict) Checkpoint state_dict.
    @param strict: (str) Key-matching strictness. When set to 'no_key_matching', a failed load falls back to
        positional key matching (see adapt_state_dict_to_fit_model_layer_names).
    @return: None
    """
    try:
        net.load_state_dict(state_dict['net'] if 'net' in state_dict.keys() else state_dict, strict=strict)
    except (RuntimeError, ValueError, KeyError) as ex:
        if strict == 'no_key_matching':
            adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict)
            net.load_state_dict(adapted_state_dict['net'], strict=True)
        else:
            raise_informative_runtime_error(net.state_dict(), state_dict, ex)
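
# USAGE SKETCH (illustrative only; 'my_net' and the checkpoint path are hypothetical):
#
#     state_dict = read_ckpt_state_dict('/path/to/ckpt_best.pth')
#     # FALLS BACK TO POSITIONAL KEY MATCHING IF THE LAYER NAMES DO NOT MATCH:
#     adaptive_load_state_dict(my_net, state_dict, strict='no_key_matching')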


@explicit_params_validation(validation_type='None')
def copy_ckpt_to_local_folder(local_ckpt_destination_dir: str, ckpt_filename: str, remote_ckpt_source_dir: str = None,
                              path_src: str = 'local', overwrite_local_ckpt: bool = False,
                              load_weights_only: bool = False):
    """
    Copies the checkpoint from any supported source to a local destination path.

    :param local_ckpt_destination_dir: Destination directory the checkpoint will be saved to.
    :param ckpt_filename: ckpt_best.pth or ckpt_latest.pth.
    :param remote_ckpt_source_dir: Name of the source checkpoint to be loaded (S3 model directory or full URL).
    :param path_src: 's3' / 'url'.
    :param overwrite_local_ckpt: Determines whether the checkpoint will be saved in the destination directory or
        in a temp folder.
    :param load_weights_only: When False, remote log files are copied alongside the checkpoint (S3 source only).
    :return: Path to the local checkpoint file.
    """
    ckpt_file_full_local_path = None

    # IF NOT DEFINED - IT IS SET TO THE TARGET's FOLDER NAME
    remote_ckpt_source_dir = local_ckpt_destination_dir if remote_ckpt_source_dir is None else remote_ckpt_source_dir

    if not overwrite_local_ckpt:
        # CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO
        download_ckpt_destination_dir = tempfile.gettempdir()
        print('PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False '
              '-> IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART')
    else:
        # SAVE THE CHECKPOINT TO THE MODEL's FOLDER
        download_ckpt_destination_dir = pkg_resources.resource_filename('checkpoints', local_ckpt_destination_dir)

    if path_src.startswith('s3'):
        model_checkpoints_data_interface = ADNNModelRepositoryDataInterfaces(data_connection_location=path_src)
        # DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER
        ckpt_file_full_local_path = model_checkpoints_data_interface.load_remote_checkpoints_file(
            ckpt_source_remote_dir=remote_ckpt_source_dir,
            ckpt_destination_local_dir=download_ckpt_destination_dir,
            ckpt_file_name=ckpt_filename,
            overwrite_local_checkpoints_file=overwrite_local_ckpt)

        if not load_weights_only:
            # COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODEL's CKPT
            model_checkpoints_data_interface.load_all_remote_log_files(model_name=remote_ckpt_source_dir,
                                                                       model_checkpoint_local_dir=download_ckpt_destination_dir)

    if path_src == 'url':
        ckpt_file_full_local_path = download_ckpt_destination_dir + os.path.sep + ckpt_filename
        # DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER
        download_url_to_file(remote_ckpt_source_dir, ckpt_file_full_local_path, progress=True)

    return ckpt_file_full_local_path
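
# USAGE SKETCH (illustrative only; the URL below is hypothetical):
#
#     local_path = copy_ckpt_to_local_folder(local_ckpt_destination_dir='my_experiment',
#                                            ckpt_filename='ckpt_best.pth',
#                                            remote_ckpt_source_dir='https://example.com/ckpt_best.pth',
#                                            path_src='url')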


def read_ckpt_state_dict(ckpt_path: str, device="cpu"):
    """
    Reads a checkpoint state dict from ckpt_path onto the requested device.

    :param ckpt_path: Local path to the checkpoint file.
    :param device: 'cpu' or 'cuda'. On 'cpu', all tensors are mapped to CPU memory regardless of the device
        they were saved from.
    :return: The loaded state dict.
    """
    if not os.path.exists(ckpt_path):
        raise ValueError('Incorrect Checkpoint path: ' + ckpt_path)

    if device == "cuda":
        state_dict = torch.load(ckpt_path)
    else:
        # MAP ALL TENSORS TO CPU, EVEN IF THE CHECKPOINT WAS SAVED FROM GPU
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    return state_dict
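
# USAGE SKETCH (illustrative only; the path below is hypothetical):
#
#     state_dict = read_ckpt_state_dict('/path/to/ckpt_latest.pth')  # LOADS SAFELY ON CPU-ONLY MACHINES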


def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict,
                                              exclude: list = None, solver: callable = None):
    """
    Given a model state dict and a source checkpoint, renames the checkpoint keys to fit the model's layer names,
    pairing entries by position, so that the weights can be properly loaded into the model. Raises ValueError on a
    shape mismatch between paired entries.

    :param model_state_dict: The model's state_dict.
    :param source_ckpt: Checkpoint dict.
    :param exclude: Optional list of layer-name substrings to exclude from the model state dict before matching.
    :param solver: Callable with signature (ckpt_key, ckpt_val, model_key, model_val) that returns the desired
        weight to use in place of ckpt_val.
    :return: The renamed checkpoint dict, wrapped as {'net': ...}.
    """
    exclude = exclude or []
    if 'net' in source_ckpt.keys():
        source_ckpt = source_ckpt['net']
    model_state_dict_excluded = {k: v for k, v in model_state_dict.items() if not any(x in k for x in exclude)}
    new_ckpt_dict = {}
    # PAIR CHECKPOINT AND MODEL ENTRIES BY POSITION AND RENAME EACH CHECKPOINT KEY TO ITS MODEL COUNTERPART
    for (ckpt_key, ckpt_val), (model_key, model_val) in zip(source_ckpt.items(), model_state_dict_excluded.items()):
        if solver is not None:
            ckpt_val = solver(ckpt_key, ckpt_val, model_key, model_val)
        if ckpt_val.shape != model_val.shape:
            raise ValueError(f'ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}'
                             f' with shape {model_val.shape} in the model')
        new_ckpt_dict[model_key] = ckpt_val
    return {'net': new_ckpt_dict}
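
# USAGE SKETCH (illustrative only; 'my_net' and the checkpoint path are hypothetical):
#
#     ckpt = read_ckpt_state_dict('/path/to/renamed_layers_ckpt.pth')
#     adapted = adapt_state_dict_to_fit_model_layer_names(my_net.state_dict(), ckpt)
#     my_net.load_state_dict(adapted['net'], strict=True)
#
# NOTE THAT MATCHING IS POSITIONAL: IT ASSUMES BOTH DICTS HOLD THE SAME LAYERS IN THE SAME ORDER.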


def raise_informative_runtime_error(state_dict, checkpoint, exception_msg):
    """
    Given a model state dict and a source checkpoint, calls adapt_state_dict_to_fit_model_layer_names
    and enhances exception_msg if loading the checkpoint via the conversion method is possible.
    """
    try:
        new_ckpt_dict = adapt_state_dict_to_fit_model_layer_names(state_dict, checkpoint)
        temp_file = tempfile.NamedTemporaryFile().name + '.pt'
        torch.save(new_ckpt_dict, temp_file)
        exception_msg = f"\n{'=' * 200}\n{str(exception_msg)} \nconvert ckpt via the utils.adapt_state_dict_to_fit_" \
                        f"model_layer_names method\na converted checkpoint file was saved in the path {temp_file}\n{'=' * 200}"
    except ValueError as ex:  # IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL
        exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do not fit, e.g.: {ex}\n{'=' * 200}"
    finally:
        raise RuntimeError(exception_msg)


def load_checkpoint_to_model(ckpt_local_path: str, load_backbone: bool, net: torch.nn.Module, strict: str,
                             load_weights_only: bool, load_ema_as_net: bool = False):
    """
    Loads the state dict in ckpt_local_path into net and returns the checkpoint's state dict.

    @param load_ema_as_net: When set, loads the EMA network stored inside the checkpoint file into the network.
    @param ckpt_local_path: Local path to the checkpoint file.
    @param load_backbone: Whether to load the checkpoint into net.module.backbone only.
    @param net: Network to load the checkpoint into.
    @param strict: Key-matching strictness (see adaptive_load_state_dict).
    @param load_weights_only: When set, all data stored in the checkpoint other than the weights is discarded.
    @return: The checkpoint's state dict.
    """
    if ckpt_local_path is None or not os.path.exists(ckpt_local_path):
        error_msg = 'Error - loading Model Checkpoint: Path {} does not exist'.format(ckpt_local_path)
        raise RuntimeError(error_msg)

    if load_backbone and not hasattr(net.module, 'backbone'):
        raise ValueError("No backbone attribute in net - Can't load backbone weights")

    # LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT
    checkpoint = read_ckpt_state_dict(ckpt_path=ckpt_local_path)

    if load_ema_as_net:
        if 'ema_net' not in checkpoint.keys():
            raise ValueError("Can't load EMA network - no EMA network stored in checkpoint file")
        else:
            checkpoint['net'] = checkpoint['ema_net']

    # LOAD THE CHECKPOINT's WEIGHTS INTO THE MODEL
    if load_backbone:
        adaptive_load_state_dict(net.module.backbone, checkpoint, strict)
    else:
        adaptive_load_state_dict(net, checkpoint, strict)

    if load_weights_only or load_backbone:
        # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
        for key in list(checkpoint.keys()):
            if key != 'net':
                checkpoint.pop(key)

    return checkpoint
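
# USAGE SKETCH (illustrative only; 'my_net' is hypothetical, and load_backbone=True assumes it is wrapped,
# e.g. in DataParallel, so that net.module exists):
#
#     checkpoint = load_checkpoint_to_model(ckpt_local_path='/path/to/ckpt_best.pth',
#                                           load_backbone=False,
#                                           net=my_net,
#                                           strict='no_key_matching',
#                                           load_weights_only=True)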


class MissingPretrainedWeightsException(Exception):
    """Exception raised when the requested pretrained model is unsupported.

    Attributes:
        message -- explanation of the error
    """

    def __init__(self, desc):
        self.message = "Missing pretrained weights: " + desc
        super().__init__(self.message)


def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):
    """
    Helper method for reshaping an old pretrained checkpoint's Focus-layer weights into 6x6 conv weights.
    """
    if ckpt_val.shape != model_val.shape and ckpt_key == 'module._backbone._modules_list.0.conv.conv.weight' and \
            model_key == '_backbone._modules_list.0.conv.weight':
        # SCATTER THE 12 INPUT CHANNELS OF THE OLD FOCUS CONV INTO THE 4 SPATIAL OFFSETS OF THE NEW 6x6 KERNEL
        model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
        model_val.data[:, :, 1::2, ::2] = ckpt_val.data[:, 3:6]
        model_val.data[:, :, ::2, 1::2] = ckpt_val.data[:, 6:9]
        model_val.data[:, :, 1::2, 1::2] = ckpt_val.data[:, 9:12]
        replacement = model_val
    else:
        replacement = ckpt_val

    return replacement
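
# BACKGROUND ON THE SOLVER ABOVE (explanatory note, not part of the original code): the legacy Focus stem performs
# a space-to-depth step - it slices the input into its 4 pixel-parity groups of 3 channels each (12 channels total)
# and applies a 3x3 conv. Newer models replace Focus with a single stride-2 6x6 conv that computes the same result,
# so each 3-channel slice of the old (out, 12, 3, 3) kernel is scattered into the matching parity positions of the
# new (out, 3, 6, 6) kernel.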


def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
    """
    Loads pretrained weights from the MODEL_URLS dictionary into model.

    @param architecture: Name of the model's architecture.
    @param model: Model to load the pretrained weights into.
    @param pretrained_weights: Name of the pretrained weights (e.g. 'imagenet').
    @return: None
    """
    model_url_key = architecture + '_' + str(pretrained_weights)
    if model_url_key not in MODEL_URLS.keys():
        raise MissingPretrainedWeightsException(model_url_key)

    url = MODEL_URLS[model_url_key]
    unique_filename = url.split("https://deci-pretrained-models.s3.amazonaws.com/")[1].replace('/', '_').replace(' ', '_')
    map_location = torch.device('cpu')
    pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
    _load_weights(architecture, model, pretrained_state_dict)
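
# USAGE SKETCH (illustrative only; whether this architecture/dataset pair exists in MODEL_URLS depends on the
# installed super_gradients version):
#
#     model = ...  # e.g. a ResNet50 instance built by super_gradients
#     load_pretrained_weights(model, architecture='resnet50', pretrained_weights='imagenet')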


def _load_weights(architecture, model, pretrained_state_dict):
    """Adapts pretrained_state_dict to the model's layer names (preferring the EMA weights, if stored) and loads it."""
    if 'ema_net' in pretrained_state_dict.keys():
        pretrained_state_dict['net'] = pretrained_state_dict['ema_net']
    solver = _yolox_ckpt_solver if "yolox" in architecture else None
    adapted_pretrained_state_dict = adapt_state_dict_to_fit_model_layer_names(model_state_dict=model.state_dict(),
                                                                              source_ckpt=pretrained_state_dict,
                                                                              solver=solver)
    model.load_state_dict(adapted_pretrained_state_dict['net'], strict=False)


def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
    """
    Loads pretrained weights from a local path into model.

    @param architecture: Name of the model's architecture.
    @param model: Model to load the pretrained weights into.
    @param pretrained_weights: Path to the pretrained weights file.
    @return: None
    """
    map_location = torch.device('cpu')
    pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
    _load_weights(architecture, model, pretrained_state_dict)
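
# USAGE SKETCH (illustrative only; the path below is hypothetical):
#
#     load_pretrained_weights_local(model, architecture='yolox_s',
#                                   pretrained_weights='/path/to/yolox_s_coco.pth')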