Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#452 Bug/sg 318 default optimizer params not taken with zero wd on bn and bias

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bug/SG-318_default_optimizer_params_not_taken_with_zero_wd_on_bn_and_bias
@@ -1,72 +1,75 @@
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
 from copy import deepcopy
 from copy import deepcopy
-DEFAULT_TRAINING_PARAMS = {"lr_warmup_epochs": 0,
-                           "lr_cooldown_epochs": 0,
-                           "warmup_initial_lr": None,
-                           "cosine_final_lr_ratio": 0.01,
-                           "optimizer": "SGD",
-                           "criterion_params": {},
-                           "ema": False,
-                           "batch_accumulate": 1,  # number of batches to accumulate before every backward pass
-                           "ema_params": {},
-                           "zero_weight_decay_on_bias_and_bn": False,
-                           "load_opt_params": True,
-                           "run_validation_freq": 1,
-                           "save_model": True,
-                           "metric_to_watch": "Accuracy",
-                           "launch_tensorboard": False,
-                           "tb_files_user_prompt": False,  # Asks User for Tensorboard Deletion Prompt
-                           "silent_mode": False,  # Silents the Print outs
-                           "mixed_precision": False,
-                           "tensorboard_port": None,
-                           "save_ckpt_epoch_list": [],  # indices where the ckpt will save automatically
-                           "average_best_models": True,
-                           "dataset_statistics": False,  # add a dataset statistical analysis and sample images to tensorboard
-                           "save_tensorboard_to_s3": False,
-                           "lr_schedule_function": None,
-                           "train_metrics_list": [],
-                           "valid_metrics_list": [],
 
 
-                           "greater_metric_to_watch_is_better": True,
-                           "precise_bn": False,
-                           "precise_bn_batch_size": None,
-                           "seed": 42,
-                           "lr_mode": None,
-                           "phase_callbacks": None,
-                           "log_installed_packages": True,
-                           "save_full_train_log": False,
-                           "sg_logger": "base_sg_logger",
-                           "sg_logger_params":
-                               {"tb_files_user_prompt": False,  # Asks User for Tensorboard Deletion Prompt
-                                "project_name": "",
-                                "launch_tensorboard": False,
-                                "tensorboard_port": None,
-                                "save_checkpoints_remote": False,  # upload checkpoint files to s3
-                                "save_tensorboard_remote": False,  # upload tensorboard files to s3
-                                "save_logs_remote": False},  # upload log files to s3
-                           "warmup_mode": "linear_step",
-                           "step_lr_update_freq": None,
-                           "lr_updates": [],
-                           'clip_grad_norm': None,
-                           'pre_prediction_callback': None,
-                           'ckpt_best_name': 'ckpt_best.pth',
-                           'enable_qat': False,
-                           'qat_params': {
-                               "start_epoch": 0,
-                               "quant_modules_calib_method": "percentile",
-                               "per_channel_quant_modules": False,
-                               "calibrate": True,
-                               "calibrated_model_path": None,
-                               "calib_data_loader": None,
-                               "num_calib_batches": 2,
-                               "percentile": 99.99
-                           },
-                           "resume": False,
-                           "resume_path": None,
-                           "ckpt_name": 'ckpt_latest.pth',
-                           "resume_strict_load": False,
-                           "sync_bn": False
-                           }
+DEFAULT_TRAINING_PARAMS = {
+    "lr_warmup_epochs": 0,
+    "lr_cooldown_epochs": 0,
+    "warmup_initial_lr": None,
+    "cosine_final_lr_ratio": 0.01,
+    "optimizer": "SGD",
+    "optimizer_params": {},
+    "criterion_params": {},
+    "ema": False,
+    "batch_accumulate": 1,  # number of batches to accumulate before every backward pass
+    "ema_params": {},
+    "zero_weight_decay_on_bias_and_bn": False,
+    "load_opt_params": True,
+    "run_validation_freq": 1,
+    "save_model": True,
+    "metric_to_watch": "Accuracy",
+    "launch_tensorboard": False,
+    "tb_files_user_prompt": False,  # Asks User for Tensorboard Deletion Prompt
+    "silent_mode": False,  # Silents the Print outs
+    "mixed_precision": False,
+    "tensorboard_port": None,
+    "save_ckpt_epoch_list": [],  # indices where the ckpt will save automatically
+    "average_best_models": True,
+    "dataset_statistics": False,  # add a dataset statistical analysis and sample images to tensorboard
+    "save_tensorboard_to_s3": False,
+    "lr_schedule_function": None,
+    "train_metrics_list": [],
+    "valid_metrics_list": [],
+    "greater_metric_to_watch_is_better": True,
+    "precise_bn": False,
+    "precise_bn_batch_size": None,
+    "seed": 42,
+    "lr_mode": None,
+    "phase_callbacks": None,
+    "log_installed_packages": True,
+    "save_full_train_log": False,
+    "sg_logger": "base_sg_logger",
+    "sg_logger_params": {
+        "tb_files_user_prompt": False,  # Asks User for Tensorboard Deletion Prompt
+        "project_name": "",
+        "launch_tensorboard": False,
+        "tensorboard_port": None,
+        "save_checkpoints_remote": False,  # upload checkpoint files to s3
+        "save_tensorboard_remote": False,  # upload tensorboard files to s3
+        "save_logs_remote": False,
+    },  # upload log files to s3
+    "warmup_mode": "linear_step",
+    "step_lr_update_freq": None,
+    "lr_updates": [],
+    "clip_grad_norm": None,
+    "pre_prediction_callback": None,
+    "ckpt_best_name": "ckpt_best.pth",
+    "enable_qat": False,
+    "qat_params": {
+        "start_epoch": 0,
+        "quant_modules_calib_method": "percentile",
+        "per_channel_quant_modules": False,
+        "calibrate": True,
+        "calibrated_model_path": None,
+        "calib_data_loader": None,
+        "num_calib_batches": 2,
+        "percentile": 99.99,
+    },
+    "resume": False,
+    "resume_path": None,
+    "ckpt_name": "ckpt_latest.pth",
+    "resume_strict_load": False,
+    "sync_bn": False,
+}
 
 
 DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}
 DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}
 
 
@@ -76,29 +79,23 @@ DEFAULT_OPTIMIZER_PARAMS_RMSPROP = {"weight_decay": 1e-4, "momentum": 0.9}
 
 
 DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF = {"weight_decay": 1e-4, "momentum": 0.9}
 DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF = {"weight_decay": 1e-4, "momentum": 0.9}
 
 
-TRAINING_PARAM_SCHEMA = {"type": "object",
-                         "properties": {
-                             "max_epochs": {"type": "number", "minimum": 1},
-
-                             # FIXME: CHECK THE IMPORTANCE OF THE COMMENTED SCHEMA- AS IT CAUSES HYDRA USE TO CRASH
-
-                             # "lr_updates": {"type": "array", "minItems": 1},
-                             "lr_decay_factor": {"type": "number", "minimum": 0, "maximum": 1},
-                             "lr_warmup_epochs": {"type": "number", "minimum": 0, "maximum": 10},
-                             "initial_lr": {"type": "number", "exclusiveMinimum": 0, "maximum": 10}
-                         },
-                         "if": {
-                             "properties": {"lr_mode": {"const": "step"}}
-                         },
-                         "then": {
-                             "required": ["lr_updates", "lr_decay_factor"]
-                         },
-                         "required": ["max_epochs", "lr_mode", "initial_lr", "loss"]
-                         }
+TRAINING_PARAM_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "max_epochs": {"type": "number", "minimum": 1},
+        # FIXME: CHECK THE IMPORTANCE OF THE COMMENTED SCHEMA- AS IT CAUSES HYDRA USE TO CRASH
+        # "lr_updates": {"type": "array", "minItems": 1},
+        "lr_decay_factor": {"type": "number", "minimum": 0, "maximum": 1},
+        "lr_warmup_epochs": {"type": "number", "minimum": 0, "maximum": 10},
+        "initial_lr": {"type": "number", "exclusiveMinimum": 0, "maximum": 10},
+    },
+    "if": {"properties": {"lr_mode": {"const": "step"}}},
+    "then": {"required": ["lr_updates", "lr_decay_factor"]},
+    "required": ["max_epochs", "lr_mode", "initial_lr", "loss"],
+}
 
 
 
 
 class TrainingParams(HpmStruct):
 class TrainingParams(HpmStruct):
-
     def __init__(self, **entries):
     def __init__(self, **entries):
         # WE initialize by the default training params, overridden by the provided params
         # WE initialize by the default training params, overridden by the provided params
         default_training_params = deepcopy(DEFAULT_TRAINING_PARAMS)
         default_training_params = deepcopy(DEFAULT_TRAINING_PARAMS)
Discard
@@ -4,17 +4,23 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.conv import _ConvNd
 from torch.nn.modules.conv import _ConvNd
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.factories.optimizers_type_factory import OptimizersTypeFactory
 from super_gradients.common.factories.optimizers_type_factory import OptimizersTypeFactory
-from super_gradients.training.params import DEFAULT_OPTIMIZER_PARAMS_SGD, DEFAULT_OPTIMIZER_PARAMS_ADAM, \
-    DEFAULT_OPTIMIZER_PARAMS_RMSPROP, DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF
+from super_gradients.training.params import (
+    DEFAULT_OPTIMIZER_PARAMS_SGD,
+    DEFAULT_OPTIMIZER_PARAMS_ADAM,
+    DEFAULT_OPTIMIZER_PARAMS_RMSPROP,
+    DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF,
+)
 from super_gradients.training.utils import get_param
 from super_gradients.training.utils import get_param
 from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF
 from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
-OPTIMIZERS_DEFAULT_PARAMS = {optim.SGD: DEFAULT_OPTIMIZER_PARAMS_SGD,
-                             optim.Adam: DEFAULT_OPTIMIZER_PARAMS_ADAM,
-                             optim.RMSprop: DEFAULT_OPTIMIZER_PARAMS_RMSPROP,
-                             RMSpropTF: DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF}
+OPTIMIZERS_DEFAULT_PARAMS = {
+    optim.SGD: DEFAULT_OPTIMIZER_PARAMS_SGD,
+    optim.Adam: DEFAULT_OPTIMIZER_PARAMS_ADAM,
+    optim.RMSprop: DEFAULT_OPTIMIZER_PARAMS_RMSPROP,
+    RMSpropTF: DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF,
+}
 
 
 
 
 def separate_zero_wd_params_groups_for_optimizer(module: nn.Module, net_named_params, weight_decay: float):
 def separate_zero_wd_params_groups_for_optimizer(module: nn.Module, net_named_params, weight_decay: float):
@@ -40,8 +46,7 @@ def separate_zero_wd_params_groups_for_optimizer(module: nn.Module, net_named_pa
             else:
             else:
                 decay_params.append(param)
                 decay_params.append(param)
         # append two param groups from the original param group, with and without weight decay.
         # append two param groups from the original param group, with and without weight decay.
-        extra_optim_params = {key: param_group[key] for key in param_group
-                              if key not in ["named_params", "weight_decay"]}
+        extra_optim_params = {key: param_group[key] for key in param_group if key not in ["named_params", "weight_decay"]}
         optimizer_param_groups.append({"params": no_decay_params, "weight_decay": 0.0, **extra_optim_params})
         optimizer_param_groups.append({"params": no_decay_params, "weight_decay": 0.0, **extra_optim_params})
         optimizer_param_groups.append({"params": decay_params, "weight_decay": weight_decay, **extra_optim_params})
         optimizer_param_groups.append({"params": decay_params, "weight_decay": weight_decay, **extra_optim_params})
 
 
@@ -65,9 +70,11 @@ def _get_no_decay_param_ids(module: nn.Module):
             no_decay_ids.append(id(m.bias))
             no_decay_ids.append(id(m.bias))
         elif hasattr(m, "bias") and isinstance(m.bias, nn.Parameter):
         elif hasattr(m, "bias") and isinstance(m.bias, nn.Parameter):
             if not isinstance(m, torch_weight_with_bias_types):
             if not isinstance(m, torch_weight_with_bias_types):
-                logger.warning(f"Module class: {m.__class__}, have a `bias` parameter attribute but is not instance of"
-                               f" torch primitive modules, this bias parameter will be part of param group with zero"
-                               f" weight decay.")
+                logger.warning(
+                    f"Module class: {m.__class__}, have a `bias` parameter attribute but is not instance of"
+                    f" torch primitive modules, this bias parameter will be part of param group with zero"
+                    f" weight decay."
+                )
             no_decay_ids.append(id(m.bias))
             no_decay_ids.append(id(m.bias))
     return no_decay_ids
     return no_decay_ids
 
 
@@ -83,27 +90,26 @@ def build_optimizer(net, lr, training_params):
         optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
         optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
     else:
     else:
         optimizer_cls = training_params.optimizer
         optimizer_cls = training_params.optimizer
-    default_optimizer_params = OPTIMIZERS_DEFAULT_PARAMS[optimizer_cls] if optimizer_cls in OPTIMIZERS_DEFAULT_PARAMS else {}
-    training_params.optimizer_params = get_param(training_params, 'optimizer_params', default_optimizer_params)
+    optimizer_params = OPTIMIZERS_DEFAULT_PARAMS[optimizer_cls].copy() if optimizer_cls in OPTIMIZERS_DEFAULT_PARAMS.keys() else dict()
+    optimizer_params.update(**training_params.optimizer_params)
+    training_params.optimizer_params = optimizer_params
 
 
-    weight_decay = get_param(training_params.optimizer_params, 'weight_decay', 0.)
+    weight_decay = get_param(training_params.optimizer_params, "weight_decay", 0.0)
     # OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
     # OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
-    if hasattr(net.module, 'initialize_param_groups'):
+    if hasattr(net.module, "initialize_param_groups"):
         # INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH 'named_params' AND OPTIMIZER's ATTRIBUTES PER GROUP
         # INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH 'named_params' AND OPTIMIZER's ATTRIBUTES PER GROUP
         net_named_params = net.module.initialize_param_groups(lr, training_params)
         net_named_params = net.module.initialize_param_groups(lr, training_params)
     else:
     else:
-        net_named_params = [{'named_params': net.named_parameters()}]
+        net_named_params = [{"named_params": net.named_parameters()}]
 
 
     if training_params.zero_weight_decay_on_bias_and_bn:
     if training_params.zero_weight_decay_on_bias_and_bn:
-        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(
-            net.module, net_named_params, weight_decay
-        )
+        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(net.module, net_named_params, weight_decay)
 
 
     else:
     else:
         # Overwrite groups to include params instead of named params
         # Overwrite groups to include params instead of named params
         for ind_group, param_group in enumerate(net_named_params):
         for ind_group, param_group in enumerate(net_named_params):
-            param_group['params'] = [param[1] for param in list(param_group['named_params'])]
-            del param_group['named_params']
+            param_group["params"] = [param[1] for param in list(param_group["named_params"])]
+            del param_group["named_params"]
             net_named_params[ind_group] = param_group
             net_named_params[ind_group] = param_group
         optimizer_training_params = net_named_params
         optimizer_training_params = net_named_params
 
 
Discard
@@ -2,9 +2,21 @@ import sys
 import unittest
 import unittest
 
 
 from tests.integration_tests.ema_train_integration_test import EMAIntegrationTest
 from tests.integration_tests.ema_train_integration_test import EMAIntegrationTest
-from tests.unit_tests import ZeroWdForBnBiasTest, SaveCkptListUnitTest, TestAverageMeter, \
-    TestRepVgg, TestWithoutTrainTest, OhemLossTest, EarlyStopTest, SegmentationTransformsTest, \
-    TestConvBnRelu, FactoriesTest, InitializeWithDataloadersTest, TrainingParamsTest
+from tests.unit_tests import (
+    ZeroWdForBnBiasTest,
+    SaveCkptListUnitTest,
+    TestAverageMeter,
+    TestRepVgg,
+    TestWithoutTrainTest,
+    OhemLossTest,
+    EarlyStopTest,
+    SegmentationTransformsTest,
+    TestConvBnRelu,
+    FactoriesTest,
+    InitializeWithDataloadersTest,
+    TrainingParamsTest,
+    TrainOptimizerParamsOverride,
+)
 from tests.end_to_end_tests import TestTrainer
 from tests.end_to_end_tests import TestTrainer
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
@@ -30,7 +42,6 @@ from tests.unit_tests.multi_scaling_test import MultiScaleTest
 
 
 
 
 class CoreUnitTestSuiteRunner:
 class CoreUnitTestSuiteRunner:
-
     def __init__(self):
     def __init__(self):
         self.test_loader = unittest.TestLoader()
         self.test_loader = unittest.TestLoader()
         self.unit_tests_suite = unittest.TestSuite()
         self.unit_tests_suite = unittest.TestSuite()
@@ -77,6 +88,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetCaching))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetCaching))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MultiScaleTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MultiScaleTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainingParamsTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainingParamsTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainOptimizerParamsOverride))
 
 
     def _add_modules_to_end_to_end_tests_suite(self):
     def _add_modules_to_end_to_end_tests_suite(self):
         """
         """
@@ -87,5 +99,5 @@ class CoreUnitTestSuiteRunner:
         self.end_to_end_tests_suite.addTest(self.test_loader.loadTestsFromModule(EMAIntegrationTest))
         self.end_to_end_tests_suite.addTest(self.test_loader.loadTestsFromModule(EMAIntegrationTest))
 
 
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
     unittest.main()
Discard
@@ -1,5 +1,6 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 from tests.unit_tests.factories_test import FactoriesTest
 from tests.unit_tests.factories_test import FactoriesTest
+from tests.unit_tests.optimizer_params_override_test import TrainOptimizerParamsOverride
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.zero_weight_decay_on_bias_bn_test import ZeroWdForBnBiasTest
 from tests.unit_tests.zero_weight_decay_on_bias_bn_test import ZeroWdForBnBiasTest
 from tests.unit_tests.save_ckpt_test import SaveCkptListUnitTest
 from tests.unit_tests.save_ckpt_test import SaveCkptListUnitTest
@@ -17,8 +18,23 @@ from tests.unit_tests.conv_bn_relu_test import TestConvBnRelu
 from tests.unit_tests.initialize_with_dataloaders_test import InitializeWithDataloadersTest
 from tests.unit_tests.initialize_with_dataloaders_test import InitializeWithDataloadersTest
 from tests.unit_tests.training_params_factory_test import TrainingParamsTest
 from tests.unit_tests.training_params_factory_test import TrainingParamsTest
 
 
-__all__ = ['ZeroWdForBnBiasTest', 'SaveCkptListUnitTest',
-           'AllArchitecturesTest', 'TestAverageMeter', 'TestRepVgg', 'TestWithoutTrainTest',
-           'StrictLoadEnumTest', 'TrainWithInitializedObjectsTest', 'TestAutoAugment',
-           'OhemLossTest', 'EarlyStopTest', 'SegmentationTransformsTest', 'PretrainedModelsUnitTest', 'TestConvBnRelu',
-           'FactoriesTest', 'InitializeWithDataloadersTest', 'TrainingParamsTest']
+__all__ = [
+    "ZeroWdForBnBiasTest",
+    "SaveCkptListUnitTest",
+    "AllArchitecturesTest",
+    "TestAverageMeter",
+    "TestRepVgg",
+    "TestWithoutTrainTest",
+    "StrictLoadEnumTest",
+    "TrainWithInitializedObjectsTest",
+    "TestAutoAugment",
+    "OhemLossTest",
+    "EarlyStopTest",
+    "SegmentationTransformsTest",
+    "PretrainedModelsUnitTest",
+    "TestConvBnRelu",
+    "FactoriesTest",
+    "InitializeWithDataloadersTest",
+    "TrainingParamsTest",
+    "TrainOptimizerParamsOverride",
+]
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
  1. import unittest
  2. from super_gradients.training.utils.utils import get_param
  3. from super_gradients import Trainer
  4. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  5. from super_gradients.training.metrics import Accuracy, Top5
  6. from super_gradients.training.models import ResNet18
  7. class TrainOptimizerParamsOverride(unittest.TestCase):
  8. def test_optimizer_params_partial_override(self):
  9. trainer = Trainer("test_optimizer_params_partial_override")
  10. net = ResNet18(num_classes=5, arch_params={})
  11. train_params = {
  12. "max_epochs": 1,
  13. "lr_updates": [1],
  14. "lr_decay_factor": 0.1,
  15. "lr_mode": "step",
  16. "lr_warmup_epochs": 0,
  17. "initial_lr": 0.1,
  18. "loss": "cross_entropy",
  19. "optimizer": "SGD",
  20. "criterion_params": {},
  21. "optimizer_params": {"momentum": 0.9},
  22. "zero_weight_decay_on_bias_and_bn": True,
  23. "train_metrics_list": [Accuracy(), Top5()],
  24. "valid_metrics_list": [Accuracy(), Top5()],
  25. "metric_to_watch": "Accuracy",
  26. "greater_metric_to_watch_is_better": True,
  27. }
  28. trainer.train(
  29. model=net,
  30. training_params=train_params,
  31. train_loader=classification_test_dataloader(batch_size=10),
  32. valid_loader=classification_test_dataloader(batch_size=10),
  33. )
  34. self.assertTrue(get_param(trainer.training_params.optimizer_params, "weight_decay"), 1e-4)
  35. self.assertTrue(get_param(trainer.training_params.optimizer_params, "momentum"), 0.9)
  36. def test_optimizer_params_full_override(self):
  37. trainer = Trainer("test_optimizer_params_full_override")
  38. net = ResNet18(num_classes=5, arch_params={})
  39. train_params = {
  40. "max_epochs": 1,
  41. "lr_updates": [1],
  42. "lr_decay_factor": 0.1,
  43. "lr_mode": "step",
  44. "lr_warmup_epochs": 0,
  45. "initial_lr": 0.1,
  46. "loss": "cross_entropy",
  47. "optimizer": "SGD",
  48. "criterion_params": {},
  49. "zero_weight_decay_on_bias_and_bn": True,
  50. "train_metrics_list": [Accuracy(), Top5()],
  51. "valid_metrics_list": [Accuracy(), Top5()],
  52. "metric_to_watch": "Accuracy",
  53. "greater_metric_to_watch_is_better": True,
  54. }
  55. trainer.train(
  56. model=net,
  57. training_params=train_params,
  58. train_loader=classification_test_dataloader(batch_size=10),
  59. valid_loader=classification_test_dataloader(batch_size=10),
  60. )
  61. self.assertTrue(get_param(trainer.training_params.optimizer_params, "weight_decay"), 1e-4)
  62. self.assertTrue(get_param(trainer.training_params.optimizer_params, "momentum"), 0.9)
Discard