Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#468 Bug/sg 399 external checkpoints fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bug/SG-399_external_checkpoints_fix
1 changed files with 1 additions and 17 deletions
  1. 1
    17
      src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -310,21 +310,6 @@ class Trainer:
         }
         }
         self.dataset_params = HpmStruct(**self.dataset_params)
         self.dataset_params = HpmStruct(**self.dataset_params)
 
 
-    def _set_ckpt_loading_attributes(self):
-        """
-        Sets checkpoint loading related attributes according to self.checkpoint_params
-        """
-        self.checkpoint = {}
-        self.strict_load = core_utils.get_param(self.checkpoint_params, "strict_load", default_val=StrictLoad.ON)
-        self.load_ema_as_net = core_utils.get_param(self.checkpoint_params, "load_ema_as_net", default_val=False)
-        self.source_ckpt_folder_name = core_utils.get_param(self.checkpoint_params, "source_ckpt_folder_name")
-        self.load_checkpoint = core_utils.get_param(self.checkpoint_params, "load_checkpoint", default_val=False)
-        self.load_backbone = core_utils.get_param(self.checkpoint_params, "load_backbone", default_val=False)
-        self.external_checkpoint_path = core_utils.get_param(self.checkpoint_params, "external_checkpoint_path")
-        if self.load_checkpoint or self.external_checkpoint_path:
-            self.load_weights_only = core_utils.get_param(self.checkpoint_params, "load_weights_only", default_val=False)
-        self.ckpt_name = core_utils.get_param(self.checkpoint_params, "ckpt_name", default_val=self.ckpt_name)
-
     def _net_to_device(self):
     def _net_to_device(self):
         """
         """
         Manipulates self.net according to self.multi_gpu
         Manipulates self.net according to self.multi_gpu
@@ -581,6 +566,7 @@ class Trainer:
         self.load_ema_as_net = False
         self.load_ema_as_net = False
         self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
         self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
         self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
         self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
+        self.load_checkpoint = self.load_checkpoint or self.external_checkpoint_path is not None
         self.ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", "ckpt_latest.pth")
         self.ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", "ckpt_latest.pth")
         self._load_checkpoint_to_model()
         self._load_checkpoint_to_model()
 
 
@@ -1448,8 +1434,6 @@ class Trainer:
          is True
          is True
         """
         """
 
 
-        self._set_ckpt_loading_attributes()
-
         if self.load_checkpoint or self.external_checkpoint_path:
         if self.load_checkpoint or self.external_checkpoint_path:
             # GET LOCAL PATH TO THE CHECKPOINT FILE FIRST
             # GET LOCAL PATH TO THE CHECKPOINT FILE FIRST
             ckpt_local_path = get_ckpt_local_path(
             ckpt_local_path = get_ckpt_local_path(
Discard