Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#274 Remove all elasticsearch references

Merged
GitHub User merged 1 commits into Deci-AI:master from deci-ai:feature/LAB-0000_remove_elasticsearch_references
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
  1. import unittest
  2. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ClassificationTestDatasetInterface
  3. from super_gradients.training import SgModel, MultiGPUMode
  4. from super_gradients.training.metrics.classification_metrics import Accuracy
  5. import os
  6. from super_gradients.training.utils.quantization_utils import PostQATConversionCallback
  7. class QATIntegrationTest(unittest.TestCase):
  8. def _get_trainer(self, experiment_name):
  9. dataset_params = {"batch_size": 10}
  10. dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
  11. model = SgModel(experiment_name,
  12. model_checkpoints_location='local',
  13. multi_gpu=MultiGPUMode.OFF)
  14. model.connect_dataset_interface(dataset)
  15. model.build_model("resnet18", checkpoint_params={"pretrained_weights": "imagenet"})
  16. return model
  17. def _get_train_params(self, qat_params):
  18. train_params = {"max_epochs": 2,
  19. "lr_mode": "step",
  20. "optimizer": "SGD",
  21. "lr_updates": [],
  22. "lr_decay_factor": 0.1,
  23. "initial_lr": 0.001, "loss": "cross_entropy",
  24. "train_metrics_list": [Accuracy()],
  25. "valid_metrics_list": [Accuracy()],
  26. "loss_logging_items_names": ["Loss"],
  27. "metric_to_watch": "Accuracy",
  28. "greater_metric_to_watch_is_better": True,
  29. "average_best_models": False,
  30. "enable_qat": True,
  31. "qat_params": qat_params,
  32. "phase_callbacks": [PostQATConversionCallback(dummy_input_size=(1, 3, 224, 224))]
  33. }
  34. return train_params
  35. def test_qat_from_start(self):
  36. model = self._get_trainer("test_qat_from_start")
  37. train_params = self._get_train_params(qat_params={
  38. "start_epoch": 0,
  39. "quant_modules_calib_method": "percentile",
  40. "calibrate": True,
  41. "num_calib_batches": 2,
  42. "percentile": 99.99
  43. })
  44. model.train(training_params=train_params)
  45. def test_qat_transition(self):
  46. model = self._get_trainer("test_qat_transition")
  47. train_params = self._get_train_params(qat_params={
  48. "start_epoch": 1,
  49. "quant_modules_calib_method": "percentile",
  50. "calibrate": True,
  51. "num_calib_batches": 2,
  52. "percentile": 99.99
  53. })
  54. model.train(training_params=train_params)
  55. def test_qat_from_calibrated_ckpt(self):
  56. model = self._get_trainer("generate_calibrated_model")
  57. train_params = self._get_train_params(qat_params={
  58. "start_epoch": 0,
  59. "quant_modules_calib_method": "percentile",
  60. "calibrate": True,
  61. "num_calib_batches": 2,
  62. "percentile": 99.99
  63. })
  64. model.train(training_params=train_params)
  65. calibrated_model_path = os.path.join(model.checkpoints_dir_path, "ckpt_calibrated_percentile_99.99.pth")
  66. model = self._get_trainer("test_qat_from_calibrated_ckpt")
  67. train_params = self._get_train_params(qat_params={
  68. "start_epoch": 0,
  69. "quant_modules_calib_method": "percentile",
  70. "calibrate": False,
  71. "calibrated_model_path": calibrated_model_path,
  72. "num_calib_batches": 2,
  73. "percentile": 99.99
  74. })
  75. model.train(training_params=train_params)
  76. if __name__ == '__main__':
  77. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file