#257 allow using an external Optimizer (not initialized outside)

Merged
Ofri Masad merged 1 commit into Deci-AI:master from deci-ai:feature/SG-184_external_optimizer
@@ -1,3 +1,4 @@
+import inspect
 import os
 import sys
 from copy import deepcopy
@@ -954,7 +955,8 @@ class SgModel:
             self._reset_best_metric()
             load_opt_params = False

-        if isinstance(self.training_params.optimizer, str):
+        if isinstance(self.training_params.optimizer, str) or \
+                (inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer, torch.optim.Optimizer)):
             self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr,
                                              training_params=self.training_params)
         elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
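Note the order of the two new checks: issubclass raises TypeError when its first argument is not a class, so the inspect.isclass guard must short-circuit first. A minimal standalone sketch of the same predicate (the helper name is illustrative, not part of this PR):

import inspect

import torch

def is_optimizer_class(obj):
    # issubclass() raises TypeError for non-class arguments,
    # so inspect.isclass() must be evaluated first
    return inspect.isclass(obj) and issubclass(obj, torch.optim.Optimizer)

assert is_optimizer_class(torch.optim.SGD)    # a class: routed to build_optimizer
assert not is_optimizer_class("SGD")          # a string: also routed to build_optimizer
assert not is_optimizer_class(torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1))  # an instance: used as-is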
@@ -79,7 +79,10 @@ def build_optimizer(net, lr, training_params):
         :param lr: initial learning rate
         :param training_params: training_parameters
     """
-    optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
+    if isinstance(training_params.optimizer, str):
+        optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
+    else:
+        optimizer_cls = training_params.optimizer
     default_optimizer_params = OPTIMIZERS_DEFAULT_PARAMS[optimizer_cls] if optimizer_cls in OPTIMIZERS_DEFAULT_PARAMS else {}
     training_params.optimizer_params = get_param(training_params, 'optimizer_params', default_optimizer_params)

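Since OPTIMIZERS_DEFAULT_PARAMS is keyed by the optimizer class, a string name and the class it maps to now resolve to the same defaults. A hedged sketch of the new dispatch, with NAME_TO_CLASS standing in for OptimizersTypeFactory and the defaults table invented for illustration:

import torch

OPTIMIZERS_DEFAULT_PARAMS = {torch.optim.SGD: {"momentum": 0.9}}  # illustrative defaults
NAME_TO_CLASS = {"SGD": torch.optim.SGD}  # stand-in for OptimizersTypeFactory

def resolve_optimizer(optimizer):
    # Strings go through the factory; classes are used directly
    if isinstance(optimizer, str):
        optimizer_cls = NAME_TO_CLASS[optimizer]
    else:
        optimizer_cls = optimizer
    default_params = OPTIMIZERS_DEFAULT_PARAMS.get(optimizer_cls, {})
    return optimizer_cls, default_params

# Both spellings resolve to the same class and defaults
assert resolve_optimizer("SGD") == resolve_optimizer(torch.optim.SGD)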
@@ -71,6 +71,24 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         model.train(train_params)
         assert lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1

+    def test_train_with_external_optimizer_class(self):
+        model = SgModel("external_optimizer_test", model_checkpoints_location='local')
+        dataset_params = {"batch_size": 10}
+        dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
+        model.connect_dataset_interface(dataset)
+
+        net = ResNet18(num_classes=5, arch_params={})
+        optimizer = SGD  # a class, not an instance
+        model.build_model(net)
+
+        train_params = {"max_epochs": 2,
+                        "lr_warmup_epochs": 0, "initial_lr": 0.3, "loss": "cross_entropy", "optimizer": optimizer,
+                        "criterion_params": {},
+                        "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
+                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
+                        "greater_metric_to_watch_is_better": True}
+        model.train(train_params)
+
     def test_train_with_reduce_on_plateau(self):
         model = SgModel("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
         dataset_params = {"batch_size": 10}
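For reference, a caller-side sketch of the API after this change: the "optimizer" entry may now be a torch.optim.Optimizer subclass, with "optimizer_params" picked up by build_optimizer (the AdamW choice and the weight_decay value are illustrative, not from this PR):

from torch.optim import AdamW

train_params = {"max_epochs": 2,
                "lr_warmup_epochs": 0, "initial_lr": 0.001, "loss": "cross_entropy",
                "optimizer": AdamW,                          # a class, resolved internally
                "optimizer_params": {"weight_decay": 0.01},  # assumed forwarded to the constructor
                "criterion_params": {},
                "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
                "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                "greater_metric_to_watch_is_better": True}
model.train(train_params)  # model built as in the test above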