|
@@ -310,21 +310,6 @@ class Trainer:
|
|
}
|
|
}
|
|
self.dataset_params = HpmStruct(**self.dataset_params)
|
|
self.dataset_params = HpmStruct(**self.dataset_params)
|
|
|
|
|
|
- def _set_ckpt_loading_attributes(self):
|
|
|
|
- """
|
|
|
|
- Sets checkpoint loading related attributes according to self.checkpoint_params
|
|
|
|
- """
|
|
|
|
- self.checkpoint = {}
|
|
|
|
- self.strict_load = core_utils.get_param(self.checkpoint_params, "strict_load", default_val=StrictLoad.ON)
|
|
|
|
- self.load_ema_as_net = core_utils.get_param(self.checkpoint_params, "load_ema_as_net", default_val=False)
|
|
|
|
- self.source_ckpt_folder_name = core_utils.get_param(self.checkpoint_params, "source_ckpt_folder_name")
|
|
|
|
- self.load_checkpoint = core_utils.get_param(self.checkpoint_params, "load_checkpoint", default_val=False)
|
|
|
|
- self.load_backbone = core_utils.get_param(self.checkpoint_params, "load_backbone", default_val=False)
|
|
|
|
- self.external_checkpoint_path = core_utils.get_param(self.checkpoint_params, "external_checkpoint_path")
|
|
|
|
- if self.load_checkpoint or self.external_checkpoint_path:
|
|
|
|
- self.load_weights_only = core_utils.get_param(self.checkpoint_params, "load_weights_only", default_val=False)
|
|
|
|
- self.ckpt_name = core_utils.get_param(self.checkpoint_params, "ckpt_name", default_val=self.ckpt_name)
|
|
|
|
-
|
|
|
|
def _net_to_device(self):
|
|
def _net_to_device(self):
|
|
"""
|
|
"""
|
|
Manipulates self.net according to self.multi_gpu
|
|
Manipulates self.net according to self.multi_gpu
|
|
@@ -581,6 +566,7 @@ class Trainer:
|
|
self.load_ema_as_net = False
|
|
self.load_ema_as_net = False
|
|
self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
|
|
self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
|
|
self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
|
|
self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
|
|
|
|
+ self.load_checkpoint = self.load_checkpoint or self.external_checkpoint_path is not None
|
|
self.ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", "ckpt_latest.pth")
|
|
self.ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", "ckpt_latest.pth")
|
|
self._load_checkpoint_to_model()
|
|
self._load_checkpoint_to_model()
|
|
|
|
|
|
@@ -1448,8 +1434,6 @@ class Trainer:
|
|
is True
|
|
is True
|
|
"""
|
|
"""
|
|
|
|
|
|
- self._set_ckpt_loading_attributes()
|
|
|
|
-
|
|
|
|
if self.load_checkpoint or self.external_checkpoint_path:
|
|
if self.load_checkpoint or self.external_checkpoint_path:
|
|
# GET LOCAL PATH TO THE CHECKPOINT FILE FIRST
|
|
# GET LOCAL PATH TO THE CHECKPOINT FILE FIRST
|
|
ckpt_local_path = get_ckpt_local_path(
|
|
ckpt_local_path = get_ckpt_local_path(
|