#279 feature: Allow custom warmup

Merged
GitHub User merged 1 commit into Deci-AI:master from deci-ai:feature/ALG-577_custom-warmup-mode
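With this change, warmup_mode in training_params accepts either one of the built-in warmup mode names (a key of LR_WARMUP_CLS_DICT) or a subclass of LRCallbackBase, which SgModel instantiates and registers as a phase callback when lr_warmup_epochs > 0. A minimal sketch of the new usage, modeled on the tests below; the LinearEpochWarmup class and the parameter values are illustrative, not part of this PR:

```python
from super_gradients.training.utils.callbacks import LRCallbackBase, Phase


class LinearEpochWarmup(LRCallbackBase):
    """Illustrative per-epoch linear warmup from warmup_initial_lr to initial_lr."""

    def __init__(self, **kwargs):
        super().__init__(Phase.TRAIN_EPOCH_START, **kwargs)
        self.warmup_initial_lr = self.training_params.warmup_initial_lr or 0.001

    def perform_scheduling(self, context):
        warmup_epochs = self.training_params.lr_warmup_epochs
        progress = context.epoch / warmup_epochs  # 0 at the first epoch, 1 when warmup ends
        self.lr = self.warmup_initial_lr + (self.initial_lr - self.warmup_initial_lr) * progress
        self.update_lr(context.optimizer, context.epoch, None)

    def is_lr_scheduling_enabled(self, context):
        return self.training_params.lr_warmup_epochs >= context.epoch


train_params = {
    # ...the usual training params (loss, optimizer, metrics, lr_mode, etc.)...
    "lr_warmup_epochs": 3,
    "initial_lr": 0.5,
    "warmup_initial_lr": 0.05,
    # NEW: a subclass of LRCallbackBase is accepted here, not only a mode name such as "linear_step"
    "warmup_mode": LinearEpochWarmup,
}
```

Passing the class (rather than an instance) keeps the existing construction path: SgModel still supplies train_loader_len, net, training_params and the other constructor kwargs to the warmup callback, exactly as it does for the built-in modes.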
@@ -55,7 +55,7 @@ from super_gradients.training.utils.checkpoint_utils import get_ckpt_local_path,
     load_checkpoint_to_model, load_pretrained_weights
 from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
 from super_gradients.training.utils.callbacks import CallbackHandler, Phase, LR_SCHEDULERS_CLS_DICT, PhaseContext, \
-    MetricsUpdateCallback, LR_WARMUP_CLS_DICT, ContextSgMethods
+    MetricsUpdateCallback, LR_WARMUP_CLS_DICT, ContextSgMethods, LRCallbackBase
 from super_gradients.common.environment import environment_config
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
@@ -907,7 +907,13 @@ class SgModel:
                                                            update_param_groups=self.update_param_groups,
                                                            **self.training_params.to_dict()))
         if self.training_params.lr_warmup_epochs > 0:
-            warmup_callback_cls = LR_WARMUP_CLS_DICT[self.training_params.warmup_mode]
+            warmup_mode = self.training_params.warmup_mode
+            if isinstance(warmup_mode, str):
+                warmup_callback_cls = LR_WARMUP_CLS_DICT[warmup_mode]
+            elif isinstance(warmup_mode, type) and issubclass(warmup_mode, LRCallbackBase):
+                warmup_callback_cls = warmup_mode
+            else:
+                raise RuntimeError('warmup_mode has to be either a name of a mode (str) or a subclass of LRCallbackBase')
             self.phase_callbacks.append(warmup_callback_cls(train_loader_len=len(self.train_loader),
                                                             net=self.net,
                                                             training_params=self.training_params,
@@ -1529,7 +1535,7 @@ class SgModel:

             self.sg_logger = SG_LOGGERS[sg_logger](**sg_logger_params)
         else:
-            raise RuntimeError('sg_logger can be either an sg_logger name (str) or a subcalss of AbstractSGLogger')
+            raise RuntimeError('sg_logger can be either an sg_logger name (str) or an instance of AbstractSGLogger')

         if not isinstance(self.sg_logger, BaseSGLogger):
             logger.warning("WARNING! Using a user-defined sg_logger: files will not be automatically written to disk!\n"
@@ -1,9 +1,36 @@
 import unittest
+
+import numpy as np
+
 from super_gradients.training import SgModel
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.datasets import ClassificationTestDatasetInterface
 from super_gradients.training.models import LeNet
-from super_gradients.training.utils.callbacks import TestLRCallback
+from super_gradients.training.utils.callbacks import TestLRCallback, LRCallbackBase, Phase
+
+
+class ExponentialWarmupLRCallback(LRCallbackBase):
+    """
+    LR scheduling callback for exponential warmup.
+    LR grows exponentially from warmup_initial_lr to initial_lr.
+    When warmup_initial_lr is None, the LR climb starts from 0.001.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(Phase.TRAIN_EPOCH_START, **kwargs)
+        self.warmup_initial_lr = self.training_params.warmup_initial_lr or 0.001
+        warmup_epochs = self.training_params.lr_warmup_epochs
+        lr_start = self.warmup_initial_lr
+        lr_end = self.initial_lr
+        self.c1 = (lr_end - lr_start) / (np.exp(warmup_epochs) - 1.)
+        self.c2 = (lr_start * np.exp(warmup_epochs) - lr_end) / (np.exp(warmup_epochs) - 1.)
+
+    def perform_scheduling(self, context):
+        self.lr = self.c1 * np.exp(context.epoch) + self.c2
+        self.update_lr(context.optimizer, context.epoch, None)
+
+    def is_lr_scheduling_enabled(self, context):
+        return self.training_params.lr_warmup_epochs >= context.epoch


 class LRWarmupTest(unittest.TestCase):
@@ -27,7 +54,8 @@ class LRWarmupTest(unittest.TestCase):
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                         "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
-                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
+                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
+                        "warmup_mode": "linear_step"}

         expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
         model.train(train_params)
@@ -48,7 +76,8 @@ class LRWarmupTest(unittest.TestCase):
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                         "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                         "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                         "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
-                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
+                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
+                        "warmup_mode": "linear_step"}

         expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
         model.train(train_params)
@@ -57,6 +86,50 @@ class LRWarmupTest(unittest.TestCase):
         # THE LRS AFTER THE UPDATE
         self.assertListEqual(lrs, expected_lrs)

+    def test_warmup_initial_lr(self):
+        # Define Model
+        net = LeNet()
+        model = SgModel("test_warmup_initial_lr", model_checkpoints_location='local')
+        model.connect_dataset_interface(self.dataset)
+        model.build_model(net, arch_params=self.arch_params)
+
+        lrs = []
+        phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
+
+        train_params = {"max_epochs": 5, "lr_updates": [], "lr_decay_factor": 0.1, "lr_mode": "step",
+                        "lr_warmup_epochs": 3, "loss": "cross_entropy", "optimizer": 'SGD',
+                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
+                        "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
+                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
+                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
+                        "warmup_mode": "linear_step", "initial_lr": 1, "warmup_initial_lr": 4.}
+
+        expected_lrs = [4., 3., 2., 1., 1.]
+        model.train(train_params)
+        self.assertListEqual(lrs, expected_lrs)
+
+    def test_custom_lr_warmup(self):
+        # Define Model
+        net = LeNet()
+        model = SgModel("custom_lr_warmup_test", model_checkpoints_location='local')
+        model.connect_dataset_interface(self.dataset)
+        model.build_model(net, arch_params=self.arch_params)
+
+        lrs = []
+        phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
+
+        train_params = {"max_epochs": 5, "lr_updates": [], "lr_decay_factor": 0.1, "lr_mode": "step",
+                        "lr_warmup_epochs": 3, "loss": "cross_entropy", "optimizer": 'SGD',
+                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
+                        "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
+                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
+                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
+                        "warmup_mode": ExponentialWarmupLRCallback, "initial_lr": 1.,  "warmup_initial_lr": 0.1}
+
+        expected_lrs = [0.1, 0.18102751585334242, 0.40128313980266034, 1.0, 1.0]
+        model.train(train_params)
+        self.assertListEqual(lrs, expected_lrs)
+

 if __name__ == '__main__':
     unittest.main()
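As a quick sanity check of the exponential warmup used in test_custom_lr_warmup: with warmup_initial_lr=0.1, initial_lr=1.0 and lr_warmup_epochs=3, the c1/c2 coefficients reproduce the expected LR sequence. This standalone sketch only re-derives the numbers from the formula in ExponentialWarmupLRCallback; it does not import the library:

```python
import numpy as np

# lr(epoch) = c1 * exp(epoch) + c2, with c1 and c2 chosen so that
# lr(0) == lr_start and lr(warmup_epochs) == lr_end
lr_start, lr_end, warmup_epochs = 0.1, 1.0, 3
c1 = (lr_end - lr_start) / (np.exp(warmup_epochs) - 1.)
c2 = (lr_start * np.exp(warmup_epochs) - lr_end) / (np.exp(warmup_epochs) - 1.)

lrs = [c1 * np.exp(epoch) + c2 for epoch in range(warmup_epochs + 1)]
print(lrs)  # approximately [0.1, 0.18102751585334242, 0.40128313980266034, 1.0]
```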