#948 Bug/sg 764 wrong ckpt when resuming with external ckpt root dir

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bug/SG-764_wrong_ckpt_when_resuming_with_external_ckpt_root_dir
@@ -58,25 +58,29 @@ def get_checkpoints_dir_path(experiment_name: str, ckpt_root_dir: str = None) ->
     return os.path.join(ckpt_root_dir, experiment_name)


-def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, external_checkpoint_path: str):
+def get_ckpt_local_path(experiment_name: str, ckpt_name: str, external_checkpoint_path: str, ckpt_root_dir: str = None) -> str:
     """
     """
     Gets the local path to the checkpoint file, which will be:
     Gets the local path to the checkpoint file, which will be:
-        - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
+        - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name/ckpt_name.
+        - external_checkpoint_path when external_checkpoint_path != None
+        - ckpt_root_dir/experiment_name/ckpt_name when ckpt_root_dir != None.
         - if the checkpoint file is remotely located:
             when overwrite_local_checkpoint=True then it will be saved in a temporary path which will be returned,
             otherwise it will be downloaded to YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name and overwrite
             YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name/ckpt_name if such file exists.
-        - external_checkpoint_path when external_checkpoint_path != None

-    :param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.
-    :param experiment_name: experiment name attr in trainer
-    :param ckpt_name: checkpoint filename
-    :param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)
-    :return:
+
+    :param experiment_name: experiment name attr in trainer
+    :param ckpt_name: checkpoint filename
+    :param external_checkpoint_path: full path to checkpoint file (that might be located outside of
+        super_gradients/checkpoints directory)
+    :param ckpt_root_dir: Local root directory path where all experiment logging directories will reside.
+        When None, it is assumed that pkg_resources.resource_filename('checkpoints', "") exists and will be used.
+
+    :return: local path of the checkpoint file (str)
     """
     """
     if external_checkpoint_path:
     if external_checkpoint_path:
         return external_checkpoint_path
         return external_checkpoint_path
     else:
     else:
-        checkpoints_folder_name = source_ckpt_folder_name or experiment_name
-        checkpoints_dir_path = get_checkpoints_dir_path(checkpoints_folder_name)
+        checkpoints_dir_path = get_checkpoints_dir_path(experiment_name, ckpt_root_dir)
         return os.path.join(checkpoints_dir_path, ckpt_name)
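With the new signature, the resolution order described in the docstring can be exercised directly. A minimal sketch, assuming the helper lives in super_gradients.training.utils.checkpoint_utils (as the import in the last hunk suggests) and using hypothetical paths:

```python
# Sketch of the fixed resolution order; the module path and all filesystem
# paths below are assumptions for illustration only.
from super_gradients.training.utils.checkpoint_utils import get_ckpt_local_path

# 1) An explicit external checkpoint always wins.
path = get_ckpt_local_path(
    experiment_name="my_experiment",
    ckpt_name="ckpt_latest.pth",
    external_checkpoint_path="/tmp/pretrained/ckpt_best.pth",
)
# -> "/tmp/pretrained/ckpt_best.pth"

# 2) With a custom ckpt_root_dir, the result is ckpt_root_dir/experiment_name/ckpt_name.
path = get_ckpt_local_path(
    experiment_name="my_experiment",
    ckpt_name="ckpt_latest.pth",
    external_checkpoint_path=None,
    ckpt_root_dir="/data/checkpoints",
)
# -> "/data/checkpoints/my_experiment/ckpt_latest.pth"

# 3) With ckpt_root_dir=None, the default super_gradients/checkpoints tree is used.
```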
@@ -151,7 +151,6 @@ class Trainer:
         self.load_weights_only = False
         self.ddp_silent_mode = is_ddp_subprocess()

-        self.source_ckpt_folder_name = None
         self.model_weight_averaging = None
         self.average_model_checkpoint_filename = "average_model.pth"
         self.start_epoch = 0
@@ -515,9 +514,8 @@ class Trainer:

         if self.training_params.average_best_models:
             self.model_weight_averaging = ModelWeightAveraging(
-                self.checkpoints_dir_path,
+                ckpt_dir=self.checkpoints_dir_path,
                 greater_is_better=self.greater_metric_to_watch_is_better,
-                source_ckpt_folder_name=self.source_ckpt_folder_name,
                 metric_to_watch=self.metric_to_watch,
                 metric_idx=self.metric_idx_in_results_tuple,
                 load_checkpoint=self.load_checkpoint,
@@ -1481,7 +1479,6 @@ class Trainer:

          strict:           See StrictLoad class documentation for details.
          load_backbone:    loads the provided checkpoint to self.net.backbone instead of self.net
-         source_ckpt_folder_name: The folder where the checkpoint is saved. By default uses the self.experiment_name

         NOTE: 'acc', 'epoch', 'optimizer_state_dict' and the logs are NOT loaded if self.zeroize_prev_train_params
          is True
@@ -1489,8 +1486,9 @@ class Trainer:

         if self.load_checkpoint or self.external_checkpoint_path:
             # GET LOCAL PATH TO THE CHECKPOINT FILE FIRST
+            ckpt_root_dir = str(Path(self.checkpoints_dir_path).parent)
             ckpt_local_path = get_ckpt_local_path(
-                source_ckpt_folder_name=self.source_ckpt_folder_name,
+                ckpt_root_dir=ckpt_root_dir,
                 experiment_name=self.experiment_name,
                 ckpt_name=self.ckpt_name,
                 external_checkpoint_path=self.external_checkpoint_path,
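The new ckpt_root_dir argument is recovered from the run directory itself: since the first hunk defines the run directory as ckpt_root_dir/experiment_name, taking its parent gives the root back. A small illustration with a hypothetical path:

```python
from pathlib import Path

# Hypothetical run directory of the form <ckpt_root_dir>/<experiment_name>,
# mirroring get_checkpoints_dir_path(); the parent directory recovers the root.
checkpoints_dir_path = "/data/checkpoints/my_experiment"
ckpt_root_dir = str(Path(checkpoints_dir_path).parent)
print(ckpt_root_dir)  # "/data/checkpoints"
```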
@@ -1,8 +1,7 @@
 import os
 import torch
 import numpy as np
-import pkg_resources
-from super_gradients.training import utils as core_utils
+from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict
 from super_gradients.training.utils.utils import move_state_dict_to_device


@@ -18,7 +17,6 @@ class ModelWeightAveraging:
         self,
         ckpt_dir,
         greater_is_better,
-        source_ckpt_folder_name=None,
         metric_to_watch="acc",
         metric_to_watch="acc",
         metric_idx=1,
         metric_idx=1,
         load_checkpoint=False,
         load_checkpoint=False,
@@ -26,16 +24,13 @@ class ModelWeightAveraging:
     ):
         """
         Init the ModelWeightAveraging
-        :param checkpoint_dir: the directory where the checkpoints are saved
+        :param ckpt_dir: the directory where the checkpoints are saved
         :param metric_to_watch: monitoring loss or acc, will be identical to that which determines best_model
         :param metric_idx:
         :param load_checkpoint: whether to load pre-existing snapshot dict.
         :param number_of_models_to_average: number of models to average
         """

-        if source_ckpt_folder_name is not None:
-            source_ckpt_file = os.path.join(source_ckpt_folder_name, "averaging_snapshots.pkl")
-            source_ckpt_file = pkg_resources.resource_filename("checkpoints", source_ckpt_file)
         self.averaging_snapshots_file = os.path.join(ckpt_dir, "averaging_snapshots.pkl")
         self.number_of_models_to_average = number_of_models_to_average
         self.metric_to_watch = metric_to_watch
@@ -43,14 +38,8 @@ class ModelWeightAveraging:
         self.greater_is_better = greater_is_better

         # if continuing training, copy previous snapshot dict if exist
-        if load_checkpoint and source_ckpt_folder_name is not None and os.path.isfile(source_ckpt_file):
-            averaging_snapshots_dict = core_utils.load_checkpoint(
-                ckpt_destination_dir=ckpt_dir,
-                source_ckpt_folder_name=source_ckpt_folder_name,
-                ckpt_filename="averaging_snapshots.pkl",
-                load_weights_only=False,
-                overwrite_local_ckpt=True,
-            )
+        if load_checkpoint and ckpt_dir is not None and os.path.isfile(self.averaging_snapshots_file):
+            averaging_snapshots_dict = read_ckpt_state_dict(self.averaging_snapshots_file)

         else:
             averaging_snapshots_dict = {"snapshot" + str(i): None for i in range(self.number_of_models_to_average)}