Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#869 Add DagsHub Logger to Super Gradients

Merged
Ghost merged 1 commit into Deci-AI:master from timho102003:dagshub_logger
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
  1. import os
  2. from typing import Union, Tuple
  3. import copy
  4. import hydra
  5. import torch.cuda
  6. from omegaconf import DictConfig
  7. from omegaconf import OmegaConf
  8. from torch import nn
  9. from super_gradients.common.abstractions.abstract_logger import get_logger
  10. from super_gradients.common.environment.device_utils import device_config
  11. from super_gradients.training import utils as core_utils, models, dataloaders
  12. from super_gradients.training.sg_trainer import Trainer
  13. from super_gradients.training.utils import get_param
  14. from super_gradients.training.utils.distributed_training_utils import setup_device
  15. from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches
logger = get_logger(__name__)

# pytorch-quantization is an optional dependency. Instead of failing at module
# import time, capture the import error (if any) so that QATTrainer can re-raise
# it lazily — only when quantization is actually requested.
try:
    from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
    from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
    from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

    # None means the quantization toolchain imported successfully.
    _imported_pytorch_quantization_failure = None
except (ImportError, NameError, ModuleNotFoundError) as import_err:
    logger.debug("Failed to import pytorch_quantization:")
    logger.debug(import_err)
    # Stored for deferred raising inside QATTrainer.train_from_config.
    _imported_pytorch_quantization_failure = import_err
  26. class QATTrainer(Trainer):
  27. @classmethod
  28. def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
  29. """
  30. Perform quantization aware training (QAT) according to a recipe configuration.
  31. This method will instantiate all the objects specified in the recipe, build and quantize the model,
  32. and calibrate the quantized model. The resulting quantized model and the output of the trainer.train()
  33. method will be returned.
  34. The quantized model will be exported to ONNX along with other checkpoints.
  35. :param cfg: The parsed DictConfig object from yaml recipe files or a dictionary.
  36. :return: A tuple containing the quantized model and the output of trainer.train() method.
  37. :rtype: Tuple[nn.Module, Tuple]
  38. :raises ValueError: If the recipe does not have the required key `quantization_params` or
  39. `checkpoint_params.checkpoint_path` in it.
  40. :raises NotImplementedError: If the recipe requests multiple GPUs or num_gpus is not equal to 1.
  41. :raises ImportError: If pytorch-quantization import was unsuccessful
  42. """
  43. if _imported_pytorch_quantization_failure is not None:
  44. raise _imported_pytorch_quantization_failure
  45. # INSTANTIATE ALL OBJECTS IN CFG
  46. cfg = hydra.utils.instantiate(cfg)
  47. # TRIGGER CFG MODIFYING CALLBACKS
  48. cfg = cls._trigger_cfg_modifying_callbacks(cfg)
  49. if "quantization_params" not in cfg:
  50. raise ValueError("Your recipe does not have quantization_params. Add them to use QAT.")
  51. if "checkpoint_path" not in cfg.checkpoint_params:
  52. raise ValueError("Starting checkpoint is a must for QAT finetuning.")
  53. num_gpus = core_utils.get_param(cfg, "num_gpus")
  54. multi_gpu = core_utils.get_param(cfg, "multi_gpu")
  55. device = core_utils.get_param(cfg, "device")
  56. if num_gpus != 1:
  57. raise NotImplementedError(
  58. f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. QAT is proven to work correctly only with multi_gpu=OFF and num_gpus=1"
  59. )
  60. setup_device(device=device, multi_gpu=multi_gpu, num_gpus=num_gpus)
  61. # INSTANTIATE DATA LOADERS
  62. train_dataloader = dataloaders.get(
  63. name=get_param(cfg, "train_dataloader"),
  64. dataset_params=copy.deepcopy(cfg.dataset_params.train_dataset_params),
  65. dataloader_params=copy.deepcopy(cfg.dataset_params.train_dataloader_params),
  66. )
  67. val_dataloader = dataloaders.get(
  68. name=get_param(cfg, "val_dataloader"),
  69. dataset_params=copy.deepcopy(cfg.dataset_params.val_dataset_params),
  70. dataloader_params=copy.deepcopy(cfg.dataset_params.val_dataloader_params),
  71. )
  72. if "calib_dataloader" in cfg:
  73. calib_dataloader_name = get_param(cfg, "calib_dataloader")
  74. calib_dataloader_params = copy.deepcopy(cfg.dataset_params.calib_dataloader_params)
  75. calib_dataset_params = copy.deepcopy(cfg.dataset_params.calib_dataset_params)
  76. else:
  77. calib_dataloader_name = get_param(cfg, "train_dataloader")
  78. calib_dataloader_params = copy.deepcopy(cfg.dataset_params.train_dataloader_params)
  79. calib_dataset_params = copy.deepcopy(cfg.dataset_params.train_dataset_params)
  80. # if we use whole dataloader for calibration, don't shuffle it
  81. # HistogramCalibrator collection routine is sensitive to order of batches and produces slightly different results
  82. # if we use several batches, we don't want it to be from one class if it's sequential in dataloader
  83. # model is in eval mode, so BNs will not be affected
  84. calib_dataloader_params.shuffle = cfg.quantization_params.calib_params.num_calib_batches is not None
  85. # we don't need training transforms during calibration, distribution of activations will be skewed
  86. calib_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms
  87. calib_dataloader = dataloaders.get(
  88. name=calib_dataloader_name,
  89. dataset_params=calib_dataset_params,
  90. dataloader_params=calib_dataloader_params,
  91. )
  92. # BUILD MODEL
  93. model = models.get(
  94. model_name=cfg.arch_params.get("model_name", None) or cfg.architecture,
  95. num_classes=cfg.get("num_classes", None) or cfg.arch_params.num_classes,
  96. arch_params=cfg.arch_params,
  97. strict_load=cfg.checkpoint_params.strict_load,
  98. pretrained_weights=cfg.checkpoint_params.pretrained_weights,
  99. checkpoint_path=cfg.checkpoint_params.checkpoint_path,
  100. load_backbone=False,
  101. )
  102. model.to(device_config.device)
  103. # QUANTIZE MODEL
  104. model.eval()
  105. fuse_repvgg_blocks_residual_branches(model)
  106. q_util = SelectiveQuantizer(
  107. default_quant_modules_calibrator_weights=cfg.quantization_params.selective_quantizer_params.calibrator_w,
  108. default_quant_modules_calibrator_inputs=cfg.quantization_params.selective_quantizer_params.calibrator_i,
  109. default_per_channel_quant_weights=cfg.quantization_params.selective_quantizer_params.per_channel,
  110. default_learn_amax=cfg.quantization_params.selective_quantizer_params.learn_amax,
  111. verbose=cfg.quantization_params.calib_params.verbose,
  112. )
  113. q_util.register_skip_quantization(layer_names=cfg.quantization_params.selective_quantizer_params.skip_modules)
  114. q_util.quantize_module(model)
  115. # CALIBRATE MODEL
  116. logger.info("Calibrating model...")
  117. calibrator = QuantizationCalibrator(
  118. verbose=cfg.quantization_params.calib_params.verbose,
  119. torch_hist=True,
  120. )
  121. calibrator.calibrate_model(
  122. model,
  123. method=cfg.quantization_params.calib_params.histogram_calib_method,
  124. calib_data_loader=calib_dataloader,
  125. num_calib_batches=cfg.quantization_params.calib_params.num_calib_batches or len(train_dataloader),
  126. percentile=get_param(cfg.quantization_params.calib_params, "percentile", 99.99),
  127. )
  128. calibrator.reset_calibrators(model) # release memory taken by calibrators
  129. # VALIDATE PTQ MODEL AND PRINT SUMMARY
  130. logger.info("Validating PTQ model...")
  131. trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
  132. valid_metrics_dict = trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=cfg.training_hyperparams.valid_metrics_list)
  133. results = ["PTQ Model Validation Results"]
  134. results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()]
  135. logger.info("\n".join(results))
  136. # TRAIN
  137. if cfg.quantization_params.ptq_only:
  138. logger.info("cfg.quantization_params.ptq_only=True. Performing PTQ only!")
  139. suffix = "ptq"
  140. res = None
  141. else:
  142. model.train()
  143. recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
  144. trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
  145. torch.cuda.empty_cache()
  146. res = trainer.train(
  147. model=model,
  148. train_loader=train_dataloader,
  149. valid_loader=val_dataloader,
  150. training_params=cfg.training_hyperparams,
  151. additional_configs_to_log=recipe_logged_cfg,
  152. )
  153. suffix = "qat"
  154. # EXPORT QUANTIZED MODEL TO ONNX
  155. input_shape = next(iter(val_dataloader))[0].shape
  156. os.makedirs(trainer.checkpoints_dir_path, exist_ok=True)
  157. qdq_onnx_path = os.path.join(trainer.checkpoints_dir_path, f"{cfg.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx")
  158. # TODO: modify SG's convert_to_onnx for quantized models and use it instead
  159. export_quantized_module_to_onnx(
  160. model=model.cpu(),
  161. onnx_filename=qdq_onnx_path,
  162. input_shape=input_shape,
  163. input_size=input_shape,
  164. train=False,
  165. )
  166. logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}")
  167. return model, res
Discard
Tip!

Press p or to see the previous file or, n or to see the next file