#378 Feature/sg 281 add kd notebook

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-281-add_kd_notebook
  1. """
  2. QAT example for Resnet18
  3. The purpose of this example is to demonstrate the usage of QAT in super_gradients.
  4. Behind the scenes, when passing enable_qat=True, a callback for QAT will be added.
  5. Once triggered, the following will happen:
  6. - The model will be rebuilt with quantized nn.modules.
  7. - The pretrained imagenet weights will be loaded to it.
  8. - We perform calibration with 2 batches from our training set (1024 samples = 8 gpus X 128 samples_per_batch).
  9. - We evaluate the calibrated model (accuracy is logged under calibrated_model_accuracy).
  10. - The calibrated checkpoint prior to QAT is saved under ckpt_calibrated_{calibration_method}.pth.
  11. - We fine tune the calibrated model for 1 epoch.
  12. Finally, once training is over- we trigger a pos-training callback that will export the ONNX files.
  13. """
  14. from super_gradients.training import Trainer, MultiGPUMode, models, dataloaders
  15. from super_gradients.training.metrics.classification_metrics import Accuracy
  16. import super_gradients
  17. from super_gradients.training.utils.quantization_utils import PostQATConversionCallback
  18. super_gradients.init_trainer()
  19. trainer = Trainer("resnet18_qat_example",
  20. model_checkpoints_location='local',
  21. multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
  22. train_loader = dataloaders.imagenet_train()
  23. valid_loader = dataloaders.imagenet_val()
  24. model = models.get("resnet18", pretrained_weights="imagenet")
  25. train_params = {"max_epochs": 1,
  26. "lr_mode": "step",
  27. "optimizer": "SGD",
  28. "lr_updates": [],
  29. "lr_decay_factor": 0.1,
  30. "initial_lr": 0.001, "loss": "cross_entropy",
  31. "train_metrics_list": [Accuracy()],
  32. "valid_metrics_list": [Accuracy()],
  33. "loss_logging_items_names": ["Loss"],
  34. "metric_to_watch": "Accuracy",
  35. "greater_metric_to_watch_is_better": True,
  36. "average_best_models": False,
  37. "enable_qat": True,
  38. "qat_params": {
  39. "start_epoch": 0, # first epoch for quantization aware training.
  40. "quant_modules_calib_method": "percentile",
  41. # statistics method for amax computation (one of [percentile, mse, entropy, max]).
  42. "calibrate": True, # whether to perform calibration.
  43. "num_calib_batches": 2, # number of batches to collect the statistics from.
  44. "percentile": 99.99 # percentile value to use when Trainer,
  45. },
  46. "phase_callbacks": [PostQATConversionCallback(dummy_input_size=(1, 3, 224, 224))]
  47. }
  48. trainer.train(model=model, training_params=train_params, train_loader=train_loader, valid_loader=valid_loader)
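
For context on the qat_params above: "percentile" calibration picks the quantization bound (amax) from the distribution of observed absolute activation values rather than from their raw maximum. Below is a minimal, self-contained sketch of that idea, not super_gradients' actual implementation; the function name percentile_amax is made up for illustration.

import numpy as np

def percentile_amax(activations: np.ndarray, percentile: float = 99.99) -> float:
    # amax is the symmetric quantization bound: values are clipped to [-amax, amax].
    return float(np.percentile(np.abs(activations), percentile))

# One extreme outlier barely moves the percentile-based bound, whereas the
# "max" method would stretch the int8 range to cover it and waste resolution.
acts = np.concatenate([np.random.randn(100_000), [50.0]])
print(percentile_amax(acts, 99.99))  # roughly the tail of the Gaussian (~4)
print(float(np.abs(acts).max()))     # 50.0, dominated by the outlier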
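
The PostQATConversionCallback in phase_callbacks is what produces the ONNX files mentioned in the docstring, tracing the model with a dummy input of dummy_input_size. As a rough sketch of that conversion step (assuming a plain torch.onnx.export; the callback's real code also has to handle the quantized modules), it amounts to something like:

import torch

def export_to_onnx(model: torch.nn.Module, out_path: str,
                   dummy_input_size=(1, 3, 224, 224)) -> None:
    # The dummy input fixes the traced input shape; its values are irrelevant.
    model.eval()
    dummy_input = torch.randn(*dummy_input_size)
    torch.onnx.export(model, dummy_input, out_path,
                      input_names=["input"], output_names=["output"])

# Hypothetical usage after training finishes:
# export_to_onnx(model, "resnet18_qat.onnx")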