|
@@ -151,7 +151,6 @@ class Trainer:
|
|
self.load_weights_only = False
|
|
self.load_weights_only = False
|
|
self.ddp_silent_mode = is_ddp_subprocess()
|
|
self.ddp_silent_mode = is_ddp_subprocess()
|
|
|
|
|
|
- self.source_ckpt_folder_name = None
|
|
|
|
self.model_weight_averaging = None
|
|
self.model_weight_averaging = None
|
|
self.average_model_checkpoint_filename = "average_model.pth"
|
|
self.average_model_checkpoint_filename = "average_model.pth"
|
|
self.start_epoch = 0
|
|
self.start_epoch = 0
|
|
@@ -515,9 +514,8 @@ class Trainer:
|
|
|
|
|
|
if self.training_params.average_best_models:
|
|
if self.training_params.average_best_models:
|
|
self.model_weight_averaging = ModelWeightAveraging(
|
|
self.model_weight_averaging = ModelWeightAveraging(
|
|
- self.checkpoints_dir_path,
|
|
|
|
|
|
+ ckpt_dir=self.checkpoints_dir_path,
|
|
greater_is_better=self.greater_metric_to_watch_is_better,
|
|
greater_is_better=self.greater_metric_to_watch_is_better,
|
|
- source_ckpt_folder_name=self.source_ckpt_folder_name,
|
|
|
|
metric_to_watch=self.metric_to_watch,
|
|
metric_to_watch=self.metric_to_watch,
|
|
metric_idx=self.metric_idx_in_results_tuple,
|
|
metric_idx=self.metric_idx_in_results_tuple,
|
|
load_checkpoint=self.load_checkpoint,
|
|
load_checkpoint=self.load_checkpoint,
|
|
@@ -1481,7 +1479,6 @@ class Trainer:
|
|
|
|
|
|
strict: See StrictLoad class documentation for details.
|
|
strict: See StrictLoad class documentation for details.
|
|
load_backbone: loads the provided checkpoint to self.net.backbone instead of self.net
|
|
load_backbone: loads the provided checkpoint to self.net.backbone instead of self.net
|
|
- source_ckpt_folder_name: The folder where the checkpoint is saved. By default uses the self.experiment_name
|
|
|
|
|
|
|
|
NOTE: 'acc', 'epoch', 'optimizer_state_dict' and the logs are NOT loaded if self.zeroize_prev_train_params
|
|
NOTE: 'acc', 'epoch', 'optimizer_state_dict' and the logs are NOT loaded if self.zeroize_prev_train_params
|
|
is True
|
|
is True
|
|
@@ -1489,8 +1486,9 @@ class Trainer:
|
|
|
|
|
|
if self.load_checkpoint or self.external_checkpoint_path:
|
|
if self.load_checkpoint or self.external_checkpoint_path:
|
|
# GET LOCAL PATH TO THE CHECKPOINT FILE FIRST
|
|
# GET LOCAL PATH TO THE CHECKPOINT FILE FIRST
|
|
|
|
+ ckpt_root_dir = str(Path(self.checkpoints_dir_path).parent)
|
|
ckpt_local_path = get_ckpt_local_path(
|
|
ckpt_local_path = get_ckpt_local_path(
|
|
- source_ckpt_folder_name=self.source_ckpt_folder_name,
|
|
|
|
|
|
+ ckpt_root_dir=ckpt_root_dir,
|
|
experiment_name=self.experiment_name,
|
|
experiment_name=self.experiment_name,
|
|
ckpt_name=self.ckpt_name,
|
|
ckpt_name=self.ckpt_name,
|
|
external_checkpoint_path=self.external_checkpoint_path,
|
|
external_checkpoint_path=self.external_checkpoint_path,
|