@@ -10,7 +10,8 @@ except (ModuleNotFoundError, ImportError, NameError):
     from torch.hub import _download_url_to_file as download_url_to_file
 
 
-def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, model_checkpoints_location: str, external_checkpoint_path: str, overwrite_local_checkpoint: bool, load_weights_only: bool):
+def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, model_checkpoints_location: str, external_checkpoint_path: str,
+                        overwrite_local_checkpoint: bool, load_weights_only: bool):
     """
     Gets the local path to the checkpoint file, which will be:
         - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
@@ -132,7 +133,8 @@ def read_ckpt_state_dict(ckpt_path: str, device="cpu"):
     return state_dict
 
 
-def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict, exclude: list = [], solver: callable=None):
+def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict,
+                                              exclude: list = [], solver: callable = None):
     """
     Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit
     the ckpt in order to properly load the weights into the model. If unsuccessful - returns None
@@ -174,7 +176,8 @@ def raise_informative_runtime_error(state_dict, checkpoint, exception_msg):
         raise RuntimeError(exception_msg)
 
 
-def load_checkpoint_to_model(ckpt_local_path: str, load_backbone: bool, net: torch.nn.Module, strict: str, load_weights_only: bool, load_ema_as_net: bool = False):
+def load_checkpoint_to_model(ckpt_local_path: str, load_backbone: bool, net: torch.nn.Module, strict: str,
+                             load_weights_only: bool, load_ema_as_net: bool = False):
     """
     Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.
 
@@ -226,12 +229,14 @@ class MissingPretrainedWeightsException(Exception):
         self.message = "Missing pretrained wights: " + desc
         super().__init__(self.message)
 
+
 def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):
     """
     Helper method for reshaping old pretrained checkpoint's focus weights to 6x6 conv weights.
     """
 
-    if ckpt_val.shape != model_val.shape and ckpt_key == 'module._backbone._modules_list.0.conv.conv.weight' and model_key == '_backbone._modules_list.0.conv.weight':
+    if ckpt_val.shape != model_val.shape and ckpt_key == 'module._backbone._modules_list.0.conv.conv.weight' and \
+            model_key == '_backbone._modules_list.0.conv.weight':
         model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
         model_val.data[:, :, 1::2, ::2] = ckpt_val.data[:, 3:6]
         model_val.data[:, :, ::2, 1::2] = ckpt_val.data[:, 6:9]
@@ -242,6 +247,7 @@ def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):
 
     return replacement
 
+
 def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
     """
@@ -262,5 +268,7 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
     if 'ema_net' in pretrained_state_dict.keys():
         pretrained_state_dict['net'] = pretrained_state_dict['ema_net']
     solver = _yolox_ckpt_solver if "yolox" in architecture else None
-    adapted_pretrained_state_dict = adapt_state_dict_to_fit_model_layer_names(model_state_dict=model.state_dict(), source_ckpt=pretrained_state_dict, solver=solver)
+    adapted_pretrained_state_dict = adapt_state_dict_to_fit_model_layer_names(model_state_dict=model.state_dict(),
+                                                                              source_ckpt=pretrained_state_dict,
+                                                                              solver=solver)
     model.load_state_dict(adapted_pretrained_state_dict['net'], strict=False)