lr_warmup_test.py 3.7 KB

import unittest

from super_gradients.training import SgModel
from super_gradients.training.metrics import Accuracy
from super_gradients.training.datasets import ClassificationTestDatasetInterface
from super_gradients.training.models import LeNet
from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext


class TestLRCallback(PhaseCallback):
    """
    Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In
    the case of multiple parameter groups (i.e. multiple learning rates) the learning rate is collected from the first
    one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.
    """

    def __init__(self, lr_placeholder):
        super(TestLRCallback, self).__init__(Phase.VALIDATION_EPOCH_END)
        self.lr_placeholder = lr_placeholder

    def __call__(self, context: PhaseContext):
        self.lr_placeholder.append(context.optimizer.param_groups[0]['lr'])


class LRWarmupTest(unittest.TestCase):
    def setUp(self) -> None:
        self.dataset_params = {"batch_size": 4}
        self.dataset = ClassificationTestDatasetInterface(dataset_params=self.dataset_params)
        self.arch_params = {'num_classes': 10}

    def test_lr_warmup(self):
        # Define model
        net = LeNet()
        model = SgModel("lr_warmup_test", model_checkpoints_location='local')
        model.connect_dataset_interface(self.dataset)
        model.build_model(net, arch_params=self.arch_params)

        lrs = []
        phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]

        train_params = {"max_epochs": 5, "lr_updates": [], "lr_decay_factor": 0.1, "lr_mode": "step",
                        "lr_warmup_epochs": 3, "initial_lr": 1, "loss": "cross_entropy", "optimizer": 'SGD',
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                        "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
        expected_lrs = [0.25, 0.5, 0.75, 1.0, 1.0]
        model.train(train_params)
        self.assertListEqual(lrs, expected_lrs)

    def test_lr_warmup_with_lr_scheduling(self):
        # Define model
        net = LeNet()
        model = SgModel("lr_warmup_test", model_checkpoints_location='local')
        model.connect_dataset_interface(self.dataset)
        model.build_model(net, arch_params=self.arch_params)

        lrs = []
        phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]

        train_params = {"max_epochs": 5, "cosine_final_lr_ratio": 0.2, "lr_mode": "cosine",
                        "lr_warmup_epochs": 3, "initial_lr": 1, "loss": "cross_entropy", "optimizer": 'SGD',
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                        "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks}
        expected_lrs = [0.25, 0.5, 0.75, 0.9236067977499791, 0.4763932022500211]
        model.train(train_params)
        # Although not seen here, the 4th epoch uses lr=1, so this is the expected list,
        # as we collect the lrs after the update.
        self.assertListEqual(lrs, expected_lrs)


if __name__ == '__main__':
    unittest.main()
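
For reference, the expected_lrs values in test_lr_warmup are consistent with a linear warmup that ramps the learning rate toward initial_lr over lr_warmup_epochs epochs, with step mode (and an empty lr_updates list) holding the rate at initial_lr afterwards. The short sketch below is not SuperGradients code; it only reproduces those expected values under the assumed formula lr = initial_lr * (epoch + 1) / (lr_warmup_epochs + 1) during warmup.

initial_lr = 1.0
lr_warmup_epochs = 3
max_epochs = 5

expected = []
for epoch in range(max_epochs):
    if epoch < lr_warmup_epochs:
        # Assumed linear warmup ramp (matches the test's expected values).
        expected.append(initial_lr * (epoch + 1) / (lr_warmup_epochs + 1))
    else:
        # Step mode with no lr_updates: the rate stays at initial_lr.
        expected.append(initial_lr)

print(expected)  # [0.25, 0.5, 0.75, 1.0, 1.0] -- matches expected_lrs in test_lr_warmup

The cosine test appears to follow the same warmup ramp for the first three epochs and then hands off to the cosine schedule, which is why only the last two expected values differ there.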