#609 CI fix

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bugfix/infra-000_ci
import argparse

from torch import nn

import super_gradients
from super_gradients import Trainer
from super_gradients.modules.quantization.resnet_bottleneck import QuantBottleneck as sg_QuantizedBottleneck
from super_gradients.training import MultiGPUMode
from super_gradients.training import models as sg_models
from super_gradients.training.dataloaders import imagenet_train, imagenet_val
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.metrics.metric_utils import get_metrics_dict
from super_gradients.training.models.classification_models.resnet import Bottleneck
from super_gradients.training.models.classification_models.resnet import Bottleneck as sg_Bottleneck
from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
from super_gradients.training.utils.quantization.core import QuantizedMetadata
from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer


def naive_quantize(model: nn.Module):
    # Quantize every supported module using the default mapping instructions.
    q_util = SelectiveQuantizer(
        default_quant_modules_calib_method_weights="max",
        default_quant_modules_calib_method_inputs="percentile",
        default_per_channel_quant_weights=True,
        default_learn_amax=False,
    )
    # SG already registers a non-naive QuantBottleneck mapping (used in selective_quantize() below);
    # pop it here for the sake of the example.
    if Bottleneck in q_util.mapping_instructions:
        q_util.mapping_instructions.pop(Bottleneck)
    q_util.quantize_module(model)
    return model


def selective_quantize(model: nn.Module):
    # Explicitly map the ResNet Bottleneck block to its quantized counterpart.
    mappings = {
        sg_Bottleneck: QuantizedMetadata(
            float_source=sg_Bottleneck,
            quantized_target_class=sg_QuantizedBottleneck,
            action=QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
        ),
    }
    sq_util = SelectiveQuantizer(
        custom_mappings=mappings,
        default_quant_modules_calib_method_weights="max",
        default_quant_modules_calib_method_inputs="percentile",
        default_per_channel_quant_weights=True,
        default_learn_amax=False,
    )
    sq_util.quantize_module(model)
    return model


def sg_vanilla_resnet50():
    return sg_models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)


def sg_naive_qdq_resnet50():
    return naive_quantize(sg_vanilla_resnet50())


def sg_selective_qdq_resnet50():
    return selective_quantize(sg_vanilla_resnet50())


models = {
    "sg_vanilla_resnet50": sg_vanilla_resnet50,
    "sg_naive_qdq_resnet50": sg_naive_qdq_resnet50,
    "sg_selective_qdq_resnet50": sg_selective_qdq_resnet50,
}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    super_gradients.init_trainer()
    parser.add_argument("--max_epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--batch", type=int, default=128)
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--calibrate", action="store_true")
    args, _ = parser.parse_known_args()

    train_params = {
        "max_epochs": args.max_epochs,
        "initial_lr": args.lr,
        "optimizer": "SGD",
        "optimizer_params": {"weight_decay": 0.0001, "momentum": 0.9, "nesterov": True},
        "loss": "cross_entropy",
        "train_metrics_list": [Accuracy(), Top5()],
        "valid_metrics_list": [Accuracy(), Top5()],
        "test_metrics_list": [Accuracy(), Top5()],
        "loss_logging_items_names": ["Loss"],
        "metric_to_watch": "Accuracy",
        "greater_metric_to_watch_is_better": True,
    }

    trainer = Trainer(experiment_name=args.model_name, multi_gpu=MultiGPUMode.OFF, device="cuda")

    train_dataloader = imagenet_train(dataloader_params={"batch_size": args.batch, "shuffle": True})
    val_dataloader = imagenet_val(dataloader_params={"batch_size": args.batch, "shuffle": True, "drop_last": True})

    model = models[args.model_name]().cuda()
    if args.calibrate:
        # Collect activation statistics on roughly 1024 training samples before QAT fine-tuning.
        calibrator = QuantizationCalibrator(verbose=True)
        calibrator.calibrate_model(model, method="percentile", calib_data_loader=train_dataloader, num_calib_batches=1024 // args.batch or 1)

    trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=val_dataloader)

    val_results_tuple = trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
    valid_metrics_dict = get_metrics_dict(val_results_tuple, trainer.test_metrics, trainer.loss_logging_items_names)

    export_quantized_module_to_onnx(model=model, onnx_filename=f"{args.model_name}.onnx", input_shape=(args.batch, 3, 224, 224))
    print(f"FINAL ACCURACY: {valid_metrics_dict['Accuracy'].cpu().item()}")