Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

trainer_test.py 6.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
  1. import shutil
  2. import unittest
  3. import super_gradients
  4. import torch
  5. import os
  6. from super_gradients import SgModel, ClassificationTestDatasetInterface
  7. from super_gradients.training.metrics import Accuracy, Top5
  8. class TestTrainer(unittest.TestCase):
  9. @classmethod
  10. def setUp(cls):
  11. super_gradients.init_trainer()
  12. # NAMES FOR THE EXPERIMENTS TO LATER DELETE
  13. cls.folder_names = ['test_train', 'test_save_load', 'test_load_w', 'test_load_w2',
  14. 'test_load_w3', 'test_checkpoint_content', 'analyze']
  15. cls.training_params = {"max_epochs": 1,
  16. "silent_mode": True,
  17. "lr_decay_factor": 0.1,
  18. "initial_lr": 0.1,
  19. "lr_updates": [4],
  20. "lr_mode": "step",
  21. "loss": "cross_entropy", "train_metrics_list": [Accuracy(), Top5()],
  22. "valid_metrics_list": [Accuracy(), Top5()],
  23. "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
  24. "greater_metric_to_watch_is_better": True}
  25. @classmethod
  26. def tearDownClass(cls) -> None:
  27. # ERASE ALL THE FOLDERS THAT WERE CREATED DURING THIS TEST
  28. for folder in cls.folder_names:
  29. if os.path.isdir(os.path.join('checkpoints', folder)):
  30. shutil.rmtree(os.path.join('checkpoints', folder))
  31. @staticmethod
  32. def get_classification_trainer(name=''):
  33. model = SgModel(name, model_checkpoints_location='local')
  34. dataset_params = {"batch_size": 4}
  35. dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
  36. model.connect_dataset_interface(dataset)
  37. model.build_model("resnet18_cifar", load_checkpoint=False)
  38. return model
  39. def test_train(self):
  40. model = self.get_classification_trainer(self.folder_names[0])
  41. model.train(training_params=self.training_params)
  42. def test_save_load(self):
  43. model = self.get_classification_trainer(self.folder_names[1])
  44. model.train(training_params=self.training_params)
  45. model.build_model("resnet18_cifar", load_checkpoint=True)
  46. def test_load_only_weights_from_ckpt(self):
  47. # Create a checkpoint with 100% accuracy
  48. model = self.get_classification_trainer(self.folder_names[2])
  49. params = self.training_params.copy()
  50. params['max_epochs'] = 3
  51. model.train(training_params=params)
  52. # Build a model that continues the training
  53. model = self.get_classification_trainer(self.folder_names[3])
  54. model.build_model('resnet18_cifar', load_checkpoint=True, source_ckpt_folder_name=self.folder_names[2],
  55. load_weights_only=False)
  56. self.assertTrue(model.best_metric > -1)
  57. self.assertTrue(model.start_epoch != 0)
  58. # start_epoch is not initialized, adding to max_epochs
  59. self.training_params['max_epochs'] += 3
  60. model.train(training_params=self.training_params)
  61. # Build a model that loads the weights and starts from scratch
  62. model = self.get_classification_trainer(self.folder_names[4])
  63. model.build_model('resnet18_cifar', load_checkpoint=True, source_ckpt_folder_name=self.folder_names[2],
  64. load_weights_only=True)
  65. self.assertTrue(model.best_metric == -1)
  66. self.assertTrue(model.start_epoch == 0)
  67. self.training_params['max_epochs'] += 3
  68. model.train(training_params=self.training_params)
  69. def test_checkpoint_content(self):
  70. """VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
  71. model = self.get_classification_trainer(self.folder_names[5])
  72. params = self.training_params.copy()
  73. params["save_ckpt_epoch_list"] = [1]
  74. model.train(training_params=params)
  75. ckpt_filename = ['ckpt_best.pth', 'ckpt_latest.pth', 'ckpt_epoch_1.pth']
  76. ckpt_paths = [os.path.join(model.checkpoints_dir_path, suf) for suf in ckpt_filename]
  77. for ckpt_path in ckpt_paths:
  78. ckpt = torch.load(ckpt_path)
  79. self.assertListEqual(['net', 'acc', 'epoch', 'optimizer_state_dict', 'scaler_state_dict'],
  80. list(ckpt.keys()))
  81. model.save_checkpoint()
  82. weights_only = torch.load(os.path.join(model.checkpoints_dir_path, 'ckpt_latest_weights_only.pth'))
  83. self.assertListEqual(['net'], list(weights_only.keys()))
  84. def test_compute_model_runtime(self):
  85. model = self.get_classification_trainer(self.folder_names[6])
  86. model.compute_model_runtime()
  87. model.compute_model_runtime(batch_sizes=1, input_dims=(3, 224, 224), verbose=False)
  88. model.compute_model_runtime(batch_sizes=[1, 2, 3], verbose=True)
  89. # VERIFY MODEL RETURNS TO PREVIOUS TRAINING MODE
  90. model.net.train()
  91. model.compute_model_runtime(batch_sizes=1, verbose=False)
  92. assert model.net.training, 'MODEL WAS SET TO eval DURING compute_model_runtime, BUT DIDN\'t RETURN TO PREVIOUS'
  93. model.net.eval()
  94. model.compute_model_runtime(batch_sizes=1, verbose=False)
  95. assert not model.net.training, 'MODEL WAS SET TO eval DURING compute_model_runtime, BUT RETURNED TO TRAINING'
  96. # THESE SHOULD HANDLE THE EXCEPTION OF CUDA OUT OF MEMORY
  97. if torch.cuda.is_available():
  98. model._switch_device('cuda')
  99. model.compute_model_runtime(batch_sizes=10000, verbose=False, input_dims=(3, 224, 224))
  100. model.compute_model_runtime(batch_sizes=[10000, 10, 50, 100, 1000, 5000], verbose=True)
  101. def test_predict(self):
  102. model = self.get_classification_trainer(self.folder_names[6])
  103. inputs = torch.randn((5, 3, 32, 32))
  104. targets = torch.randint(0, 5, (5, 1))
  105. model.predict(inputs=inputs, targets=targets)
  106. model.predict(inputs=inputs, targets=targets, half=True)
  107. model.predict(inputs=inputs, targets=targets, half=False, verbose=True)
  108. if __name__ == '__main__':
  109. unittest.main()
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...