Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#561 Feature/SG-193 extend output formatter

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-193-extend_detection_target_transform
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
  1. import os
  2. import unittest
  3. from copy import deepcopy
  4. from super_gradients.training import Trainer
  5. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  6. from super_gradients.training.metrics import Accuracy, Top5
  7. from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
  8. from super_gradients.training.utils.utils import check_models_have_same_weights
  9. from super_gradients.training.models import LeNet
  10. class FirstEpochInfoCollector(PhaseCallback):
  11. def __init__(self):
  12. super().__init__(phase=Phase.TRAIN_EPOCH_START)
  13. self.called = False
  14. self.first_epoch = None
  15. self.first_epoch_net = None
  16. def __call__(self, context: PhaseContext):
  17. if not self.called:
  18. self.first_epoch = context.epoch
  19. self.first_epoch_net = deepcopy(context.net)
  20. self.called = True
  21. class ResumeTrainingTest(unittest.TestCase):
  22. def test_resume_training(self):
  23. train_params = {
  24. "max_epochs": 2,
  25. "lr_updates": [1],
  26. "lr_decay_factor": 0.1,
  27. "lr_mode": "step",
  28. "lr_warmup_epochs": 0,
  29. "initial_lr": 0.1,
  30. "loss": "cross_entropy",
  31. "optimizer": "SGD",
  32. "criterion_params": {},
  33. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  34. "train_metrics_list": [Accuracy(), Top5()],
  35. "valid_metrics_list": [Accuracy(), Top5()],
  36. "metric_to_watch": "Accuracy",
  37. "greater_metric_to_watch_is_better": True,
  38. }
  39. # Define Model
  40. net = LeNet()
  41. trainer = Trainer("test_resume_training")
  42. trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
  43. # TRAIN FOR 1 MORE EPOCH AND COMPARE THE NET AT THE BEGINNING OF EPOCH 3 AND THE END OF EPOCH NUMBER 2
  44. resume_net = LeNet()
  45. trainer = Trainer("test_resume_training")
  46. first_epoch_cb = FirstEpochInfoCollector()
  47. train_params["resume"] = True
  48. train_params["max_epochs"] = 3
  49. train_params["phase_callbacks"] = [first_epoch_cb]
  50. trainer.train(
  51. model=resume_net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
  52. )
  53. # ASSERT RELOADED MODEL HAS THE SAME WEIGHTS AS THE MODEL SAVED IN FIRST PART OF TRAINING
  54. self.assertTrue(check_models_have_same_weights(net, first_epoch_cb.first_epoch_net))
  55. # ASSERT WE START FROM THE RIGHT EPOCH NUMBER
  56. self.assertTrue(first_epoch_cb.first_epoch == 2)
  57. def test_resume_external_training(self):
  58. train_params = {
  59. "max_epochs": 2,
  60. "lr_updates": [1],
  61. "lr_decay_factor": 0.1,
  62. "lr_mode": "step",
  63. "lr_warmup_epochs": 0,
  64. "initial_lr": 0.1,
  65. "loss": "cross_entropy",
  66. "optimizer": "SGD",
  67. "criterion_params": {},
  68. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  69. "train_metrics_list": [Accuracy(), Top5()],
  70. "valid_metrics_list": [Accuracy(), Top5()],
  71. "metric_to_watch": "Accuracy",
  72. "greater_metric_to_watch_is_better": True,
  73. }
  74. # Define Model
  75. net = LeNet()
  76. trainer = Trainer("test_resume_training")
  77. trainer.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
  78. # TRAIN FOR 1 MORE EPOCH AND COMPARE THE NET AT THE BEGINNING OF EPOCH 3 AND THE END OF EPOCH NUMBER 2
  79. resume_net = LeNet()
  80. resume_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_latest.pth")
  81. # SET DIFFERENT EXPERIMENT NAME SO WE LOAD A CHECKPOINT THAT HAS A DIFFERENT PATH FROM THE DEFAULT ONE
  82. trainer = Trainer("test_resume_external_training")
  83. first_epoch_cb = FirstEpochInfoCollector()
  84. train_params["resume_path"] = resume_path
  85. train_params["max_epochs"] = 3
  86. train_params["phase_callbacks"] = [first_epoch_cb]
  87. trainer.train(
  88. model=resume_net, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
  89. )
  90. # ASSERT RELOADED MODEL HAS THE SAME WEIGHTS AS THE MODEL SAVED IN FIRST PART OF TRAINING
  91. self.assertTrue(check_models_have_same_weights(net, first_epoch_cb.first_epoch_net))
  92. # ASSERT WE START FROM THE RIGHT EPOCH NUMBER
  93. self.assertTrue(first_epoch_cb.first_epoch == 2)
  94. if __name__ == "__main__":
  95. unittest.main()
Discard
Tip!

Press p to see the previous file, or n to see the next file