Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

trainer_test.py 6.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
  1. import shutil
  2. import unittest
  3. from super_gradients.common.object_names import Models
  4. from super_gradients.training import models
  5. import super_gradients
  6. import torch
  7. import os
  8. from super_gradients import Trainer
  9. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  10. from super_gradients.training.metrics import Accuracy, Top5
  11. from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
  12. class TestTrainer(unittest.TestCase):
  13. @classmethod
  14. def setUp(cls):
  15. super_gradients.init_trainer()
  16. # NAMES FOR THE EXPERIMENTS TO LATER DELETE
  17. cls.experiment_names = [
  18. "test_train",
  19. "test_save_load",
  20. "test_load_w",
  21. "test_load_w2",
  22. "test_load_w3",
  23. "test_checkpoint_content",
  24. "analyze",
  25. "test_yaml_metrics_present",
  26. ]
  27. cls.training_params = {
  28. "max_epochs": 1,
  29. "silent_mode": True,
  30. "lr_decay_factor": 0.1,
  31. "initial_lr": 0.1,
  32. "lr_updates": [4],
  33. "lr_mode": "StepLRScheduler",
  34. "loss": "CrossEntropyLoss",
  35. "train_metrics_list": [Accuracy(), Top5()],
  36. "valid_metrics_list": [Accuracy(), Top5()],
  37. "metric_to_watch": "Accuracy",
  38. "greater_metric_to_watch_is_better": True,
  39. }
  40. @classmethod
  41. def tearDownClass(cls) -> None:
  42. # ERASE ALL THE EXPERIMENT FOLDERS THAT WERE CREATED DURING THIS TEST
  43. for experiment_name in cls.experiment_names:
  44. experiment_dir = get_checkpoints_dir_path(experiment_name=experiment_name)
  45. if os.path.isdir(experiment_dir):
  46. # TODO: Occasionally this method fails because log files are still open (See setup_logging() call).
  47. # TODO: Need to find a way to close them at the end of training, this is however tricky to achieve
  48. # TODO: because setup_logging() called outside of Trainer class.
  49. shutil.rmtree(experiment_dir, ignore_errors=True)
  50. @staticmethod
  51. def get_classification_trainer(name=""):
  52. trainer = Trainer(name)
  53. model = models.get(Models.RESNET18, num_classes=5)
  54. return trainer, model
  55. def test_train(self):
  56. trainer, model = self.get_classification_trainer(self.experiment_names[0])
  57. trainer.train(
  58. model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
  59. )
  60. def test_save_load(self):
  61. trainer, model = self.get_classification_trainer(self.experiment_names[1])
  62. trainer.train(
  63. model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
  64. )
  65. resume_training_params = self.training_params.copy()
  66. resume_training_params["resume"] = True
  67. resume_training_params["max_epochs"] = 2
  68. trainer, model = self.get_classification_trainer(self.experiment_names[1])
  69. trainer.train(
  70. model=model, training_params=resume_training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
  71. )
  72. def test_checkpoint_content(self):
  73. """VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
  74. trainer, model = self.get_classification_trainer(self.experiment_names[5])
  75. params = self.training_params.copy()
  76. params["save_ckpt_epoch_list"] = [1]
  77. trainer.train(model=model, training_params=params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
  78. ckpt_filename = ["ckpt_best.pth", "ckpt_latest.pth", "ckpt_epoch_1.pth"]
  79. ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
  80. for ckpt_path in ckpt_paths:
  81. ckpt = torch.load(ckpt_path)
  82. self.assertListEqual(sorted(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict", "metrics", "packages"]), sorted(list(ckpt.keys())))
  83. trainer._save_checkpoint()
  84. weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
  85. self.assertListEqual(["net"], list(weights_only.keys()))
  86. def test_validation_frequency_divisible(self):
  87. trainer, model = self.get_classification_trainer(self.experiment_names[0])
  88. training_params = self.training_params.copy()
  89. training_params["max_epochs"] = 4
  90. training_params["run_validation_freq"] = 2
  91. trainer.train(
  92. model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
  93. )
  94. ckpt_filename = ["ckpt_best.pth", "ckpt_latest.pth"]
  95. ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
  96. metrics = {}
  97. for ckpt_path in ckpt_paths:
  98. ckpt = torch.load(ckpt_path)
  99. metrics[ckpt_path] = ckpt["metrics"]
  100. self.assertTrue(metrics[ckpt_paths[0]] == metrics[ckpt_paths[1]])
  101. def test_validation_frequency_and_save_ckpt_list(self):
  102. trainer, model = self.get_classification_trainer(self.experiment_names[0])
  103. training_params = self.training_params.copy()
  104. training_params["max_epochs"] = 5
  105. training_params["run_validation_freq"] = 3
  106. training_params["save_ckpt_epoch_list"] = [1]
  107. trainer.train(
  108. model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
  109. )
  110. ckpt_filename = ["ckpt_epoch_1.pth", "ckpt_latest.pth"]
  111. ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
  112. for ckpt_path in ckpt_paths:
  113. ckpt = torch.load(ckpt_path)
  114. self.assertTrue("valid" in ckpt["metrics"].keys())
  115. if __name__ == "__main__":
  116. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...