#304 Feature/sg 000 qat

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-000_QAT
23 changed files with 1818 additions and 548 deletions
1. .circleci/config.yml (+1 / -1)
2. requirements.txt (+3 / -1)
3. setup.py (+3 / -1)
4. src/super_gradients/examples/quantization/__init__.py (+0 / -0)
5. src/super_gradients/examples/quantization/non_default_calibrators_example.py (+31 / -0)
6. src/super_gradients/examples/quantization/ptq_e2e_example.py (+69 / -0)
7. src/super_gradients/examples/quantization/register_quantization_mapping_with_decorator_example.py (+61 / -0)
8. src/super_gradients/examples/quantization/resnet_qat_example.py (+104 / -0)
9. src/super_gradients/examples/quantization/skipping_quantization_example.py (+34 / -0)
10. src/super_gradients/examples/quantization/vanilla_quantize_all_example.py (+31 / -0)
11. src/super_gradients/examples/resnet_qat/resnet_qat_example.py (+0 / -59)
12. src/super_gradients/modules/quantization/__init__.py (+3 / -0)
13. src/super_gradients/modules/quantization/resnet_bottleneck.py (+27 / -0)
14. src/super_gradients/training/sg_trainer/sg_trainer.py (+3 / -5)
15. src/super_gradients/training/utils/quantization/__init__.py (+9 / -0)
16. src/super_gradients/training/utils/quantization/calibrator.py (+124 / -0)
17. src/super_gradients/training/utils/quantization/core.py (+178 / -0)
18. src/super_gradients/training/utils/quantization/export.py (+40 / -0)
19. src/super_gradients/training/utils/quantization/selective_quantization_utils.py (+286 / -0)
20. src/super_gradients/training/utils/quantization_utils.py (+0 / -388)
21. tests/deci_core_unit_test_suite_runner.py (+2 / -0)
22. tests/integration_tests/qat_integration_test.py (+0 / -93)
23. tests/unit_tests/quantization_utility_tests.py (+809 / -0)
.circleci/config.yml

@@ -72,7 +72,7 @@ jobs:
             python3 -m venv venv
             . venv/bin/activate
             python3 -m pip install pip==22.0.4
-            cat requirements.txt | cut -f1 -d"#" | xargs -n 1 -L 1 pip install --progress-bar off
+            cat requirements.txt | cut -f1 -d"#" | grep "^[^--;]" | xargs -n 1 -L 1 pip install --progress-bar off --extra-index-url https://pypi.ngc.nvidia.com
      - run:
          name: edit package version
          command: |
requirements.txt

@@ -30,4 +30,6 @@ packaging>=20.4
 # not directly required, pinned by Snyk to avoid a vulnerability
 wheel>=0.38.0
 # not directly required, pinned by Snyk to avoid a vulnerability
-pygments>=2.7.4
+pygments>=2.7.4
+--extra-index-url https://pypi.ngc.nvidia.com
+pytorch-quantization==2.1.2
setup.py

@@ -21,7 +21,8 @@ def readme():

 def get_requirements():
     with open(REQ_LOCATION, encoding="utf-8") as f:
-        return f.read().splitlines()
+        requirements = f.read().splitlines()
+        return [r for r in requirements if not r.startswith("--") and not r.startswith("#")]


 def get_pro_requirements():
@@ -45,6 +46,7 @@ setup(
     install_requires=get_requirements(),
     packages=find_packages(where="./src"),
     package_dir={"": "src"},
+    dependency_links=["https://pypi.ngc.nvidia.com"],
     package_data={
         "super_gradients.recipes": ["*.yaml", "**/*.yaml"],
         "super_gradients.common": ["auto_logging/auto_logging_conf.json"],
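With the NVIDIA index and pip option line now living in requirements.txt, get_requirements() has to drop option lines and comments before handing the list to install_requires. A standalone sketch of that filtering behaviour, not part of the diff (the sample lines are inlined here for illustration; the real code reads them from requirements.txt):

# Standalone sketch of the filtering logic added to setup.py.
sample_requirements = [
    "wheel>=0.38.0",
    "# not directly required, pinned by Snyk to avoid a vulnerability",
    "pygments>=2.7.4",
    "--extra-index-url https://pypi.ngc.nvidia.com",
    "pytorch-quantization==2.1.2",
]

install_requires = [r for r in sample_requirements if not r.startswith("--") and not r.startswith("#")]
print(install_requires)  # ['wheel>=0.38.0', 'pygments>=2.7.4', 'pytorch-quantization==2.1.2']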
src/super_gradients/examples/quantization/non_default_calibrators_example.py

import torch
from torch import nn

from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer


def non_default_calibrators_example():
    class MyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv1(x)

    module = MyModel()

    # Initialize the quantization utility, with different calibrators, and quantize the module
    q_util = SelectiveQuantizer(default_quant_modules_calib_method="percentile", default_per_channel_quant_modules=False)
    q_util.quantize_module(module)

    print(module)  # You should expect to see QuantConv2d, with Histogram calibrators

    x = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        y = module(x)
        torch.testing.assert_close(y.size(), (1, 8, 32, 32))


if __name__ == "__main__":
    non_default_calibrators_example()
src/super_gradients/examples/quantization/ptq_e2e_example.py

import torch
from pytorch_quantization import nn as quant_nn
from torch import nn

from super_gradients.training.dataloaders import cifar10_train
from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
from super_gradients.training.utils.quantization.core import SGQuantMixin
from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer


def e2e_example():
    class MyBlock(nn.Module):
        def __init__(self, in_feats, out_feats) -> None:
            super().__init__()
            self.in_feats = in_feats
            self.out_feats = out_feats
            self.flatten = nn.Flatten()
            self.linear = nn.Linear(in_feats, out_feats)

        def forward(self, x):
            return self.linear(self.flatten(x))

    class MyQuantizedBlock(SGQuantMixin):
        def __init__(self, in_feats, out_feats) -> None:
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear = quant_nn.QuantLinear(in_feats, out_feats)

        def forward(self, x):
            return self.linear(self.flatten(x))

    class MyModel(nn.Module):
        def __init__(self, res, n_classes) -> None:
            super().__init__()
            self.my_block = MyBlock(3 * (res**2), n_classes)

        def forward(self, x):
            return self.my_block(x)

    res = 32
    n_clss = 10
    module = MyModel(res, n_clss)

    # QUANTIZE
    q_util = SelectiveQuantizer()
    q_util.register_quantization_mapping(layer_names={"my_block"}, quantized_target_class=MyQuantizedBlock)
    q_util.quantize_module(module)

    # CALIBRATE (PTQ)
    train_loader = cifar10_train()
    calib = QuantizationCalibrator()
    calib.calibrate_model(module, method=q_util.default_quant_modules_calib_method, calib_data_loader=train_loader)

    module.cuda()

    # SANITY
    x = torch.rand(1, 3, res, res, device="cuda")
    with torch.no_grad():
        y = module(x)
        torch.testing.assert_close(y.size(), (1, n_clss))

    print(module)

    # EXPORT TO ONNX
    export_quantized_module_to_onnx(module, "my_quantized_model.onnx", input_shape=(1, 3, res, res))


if __name__ == "__main__":
    e2e_example()
src/super_gradients/examples/quantization/register_quantization_mapping_with_decorator_example.py

import torch
from pytorch_quantization import nn as quant_nn
from torch import nn

from super_gradients.training.utils.quantization.core import SGQuantMixin
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer, register_quantized_module


def register_quantization_mapping_with_decorator_example():
    # ARRANGE
    class MyBlock(nn.Module):
        def __init__(self, in_feats, out_feats) -> None:
            super().__init__()
            self.in_feats = in_feats
            self.out_feats = out_feats
            self.flatten = nn.Flatten()
            self.linear = nn.Linear(in_feats, out_feats)

        def forward(self, x):
            return self.linear(self.flatten(x))

    @register_quantized_module(float_source=MyBlock)
    class MyQuantizedBlock(SGQuantMixin):
        def __init__(self, in_feats, out_feats) -> None:
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear = quant_nn.QuantLinear(in_feats, out_feats)

        def forward(self, x):
            return self.linear(self.flatten(x))

    class MyModel(nn.Module):
        def __init__(self, res, n_classes) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
            self.my_block = MyBlock(4 * (res**2), n_classes)

        def forward(self, x):
            y = self.conv(x)
            return self.my_block(y)

    res = 32
    n_clss = 10
    module = MyModel(res, n_clss)

    # TEST
    q_util = SelectiveQuantizer()
    q_util.quantize_module(module)

    x = torch.rand(1, 3, res, res)
    print(module)

    # ASSERT
    with torch.no_grad():
        y = module(x)
        torch.testing.assert_close(y.size(), (1, n_clss))


if __name__ == "__main__":
    register_quantization_mapping_with_decorator_example()
src/super_gradients/examples/quantization/resnet_qat_example.py

import argparse

from torch import nn

import super_gradients
from super_gradients import Trainer
from super_gradients.training import MultiGPUMode
from super_gradients.training import models as sg_models
from super_gradients.training.dataloaders import imagenet_train, imagenet_val
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.metrics.metric_utils import get_metrics_dict
from super_gradients.training.models.classification_models.resnet import Bottleneck as sg_Bottleneck
from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
from super_gradients.training.utils.quantization.core import QuantizedMetadata
from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
from super_gradients.modules.quantization.resnet_bottleneck import QuantBottleneck as sg_QuantizedBottleneck
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer


def naive_quantize(model: nn.Module):
    q_util = SelectiveQuantizer(default_quant_modules_calib_method="max", default_per_channel_quant_modules=True)
    q_util.quantize_module(model)
    return model


def selective_quantize(model: nn.Module):
    mappings = {
        sg_Bottleneck: QuantizedMetadata(
            float_source=sg_Bottleneck,
            quantized_target_class=sg_QuantizedBottleneck,
            action=QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
        ),
    }

    sq_util = SelectiveQuantizer(custom_mappings=mappings, default_quant_modules_calib_method="max", default_per_channel_quant_modules=True)
    sq_util.quantize_module(model)
    return model


def sg_vanilla_resnet50():
    return sg_models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)


def sg_naive_qdq_resnet50():
    return naive_quantize(sg_vanilla_resnet50())


def sg_selective_qdq_resnet50():
    return selective_quantize(sg_vanilla_resnet50())


models = {
    "sg_vanilla_resnet50": sg_vanilla_resnet50,
    "sg_naive_qdq_resnet50": sg_naive_qdq_resnet50,
    "sg_selective_qdq_resnet50": sg_selective_qdq_resnet50,
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    super_gradients.init_trainer()

    parser.add_argument("--max_epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--batch", type=int, default=128)
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--calibrate", action="store_true")
    args, _ = parser.parse_known_args()

    train_params = {
        "max_epochs": args.max_epochs,
        "initial_lr": args.lr,
        "optimizer": "SGD",
        "optimizer_params": {"weight_decay": 0.0001, "momentum": 0.9, "nesterov": True},
        "loss": "cross_entropy",
        "train_metrics_list": [Accuracy(), Top5()],
        "valid_metrics_list": [Accuracy(), Top5()],
        "test_metrics_list": [Accuracy(), Top5()],
        "loss_logging_items_names": ["Loss"],
        "metric_to_watch": "Accuracy",
        "greater_metric_to_watch_is_better": True,
    }

    trainer = Trainer(experiment_name=args.model_name, multi_gpu=MultiGPUMode.OFF, device="cuda")
    train_dataloader = imagenet_train(dataloader_params={"batch_size": args.batch, "shuffle": True})
    val_dataloader = imagenet_val(dataloader_params={"batch_size": args.batch, "shuffle": True, "drop_last": True})

    model = models[args.model_name]().cuda()

    if args.calibrate:
        calibrator = QuantizationCalibrator(verbose=False)
        calibrator.calibrate_model(model, method="max", calib_data_loader=train_dataloader, num_calib_batches=1024 // args.batch or 1)

    trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=val_dataloader)

    val_results_tuple = trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=[Accuracy()], metrics_progress_verbose=True)

    valid_metrics_dict = get_metrics_dict(val_results_tuple, trainer.test_metrics, trainer.loss_logging_items_names)

    export_quantized_module_to_onnx(model=model, onnx_filename=f"{args.model_name}.onnx", input_shape=(args.batch, 3, 224, 224))

    print(f"FINAL ACCURACY: {valid_metrics_dict['Accuracy'].cpu().item()}")
src/super_gradients/examples/quantization/skipping_quantization_example.py

import torch
from torch import nn

from super_gradients.training.utils.quantization.core import SkipQuantization
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer


def skipping_quantization_example():
    class MyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.conv2 = SkipQuantization(nn.Conv2d(8, 8, kernel_size=3, padding=1))  # can use the wrapper to skip

        def forward(self, x):
            return self.conv2(self.conv1(x))

    module = MyModel()

    # Initialize the quantization utility, register layers to skip, and quantize the module
    q_util = SelectiveQuantizer()
    q_util.register_skip_quantization(layer_names={"conv1"})  # can also configure skip by layer names
    q_util.quantize_module(module)

    print(module)  # You should expect to see Conv2d

    x = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        y = module(x)
        torch.testing.assert_close(y.size(), (1, 8, 32, 32))


if __name__ == "__main__":
    skipping_quantization_example()
src/super_gradients/examples/quantization/vanilla_quantize_all_example.py

import torch
from torch import nn

from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer


def vanilla_quantize_all_example():
    class MyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv1(x)

    module = MyModel()

    # Initialize the quantization utility, and quantize the module
    q_util = SelectiveQuantizer()
    q_util.quantize_module(module)

    print(module)  # You should expect to see QuantConv2d

    x = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        y = module(x)
        torch.testing.assert_close(y.size(), (1, 8, 32, 32))


if __name__ == "__main__":
    vanilla_quantize_all_example()
src/super_gradients/examples/resnet_qat/resnet_qat_example.py (removed in this PR)

"""
QAT example for Resnet18

The purpose of this example is to demonstrate the usage of QAT in super_gradients.
Behind the scenes, when passing enable_qat=True, a callback for QAT will be added.
Once triggered, the following will happen:
 - The model will be rebuilt with quantized nn.modules.
 - The pretrained imagenet weights will be loaded to it.
 - We perform calibration with 2 batches from our training set (1024 samples = 8 gpus X 128 samples_per_batch).
 - We evaluate the calibrated model (accuracy is logged under calibrated_model_accuracy).
 - The calibrated checkpoint prior to QAT is saved under ckpt_calibrated_{calibration_method}.pth.
 - We fine tune the calibrated model for 1 epoch.

Finally, once training is over, we trigger a post-training callback that will export the ONNX files.
"""
from super_gradients.training import Trainer, MultiGPUMode, models, dataloaders
from super_gradients.training.metrics.classification_metrics import Accuracy
import super_gradients
from super_gradients.training.utils.quantization_utils import PostQATConversionCallback

super_gradients.init_trainer()

trainer = Trainer("resnet18_qat_example", multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)

train_loader = dataloaders.imagenet_train()
valid_loader = dataloaders.imagenet_val()

model = models.get("resnet18", pretrained_weights="imagenet")

train_params = {
    "max_epochs": 1,
    "lr_mode": "step",
    "optimizer": "SGD",
    "lr_updates": [],
    "lr_decay_factor": 0.1,
    "initial_lr": 0.001,
    "loss": "cross_entropy",
    "train_metrics_list": [Accuracy()],
    "valid_metrics_list": [Accuracy()],
    "metric_to_watch": "Accuracy",
    "greater_metric_to_watch_is_better": True,
    "average_best_models": False,
    "enable_qat": True,
    "qat_params": {
        "start_epoch": 0,  # first epoch for quantization aware training
        "quant_modules_calib_method": "percentile",  # statistics method for amax computation (one of [percentile, mse, entropy, max])
        "calibrate": True,  # whether to perform calibration
        "num_calib_batches": 2,  # number of batches to collect the statistics from
        "percentile": 99.99,  # percentile value to use when quant_modules_calib_method='percentile'
    },
    "phase_callbacks": [PostQATConversionCallback(dummy_input_size=(1, 3, 224, 224))],
}

trainer.train(model=model, training_params=train_params, train_loader=train_loader, valid_loader=valid_loader)
src/super_gradients/modules/quantization/__init__.py

from .resnet_bottleneck import QuantBottleneck

__all__ = ["QuantBottleneck"]
src/super_gradients/modules/quantization/resnet_bottleneck.py

from super_gradients.training.models import Bottleneck

try:
    from pytorch_quantization import nn as quant_nn

    from super_gradients.training.utils.quantization.core import SGQuantMixin, QuantizedMetadata
    from super_gradients.training.utils.quantization.selective_quantization_utils import register_quantized_module

    _imported_pytorch_quantization_failure = None
except (ImportError, NameError, ModuleNotFoundError) as import_err:
    _imported_pytorch_quantization_failure = import_err


@register_quantized_module(float_source=Bottleneck, action=QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE)
class QuantBottleneck(SGQuantMixin):
    """
    we just insert quantized tensor to the shortcut (=residual) layer, so that it would be quantized
    NOTE: we must quantize the float instance, so the mode should be
    QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE
    """

    if _imported_pytorch_quantization_failure is not None:
        raise _imported_pytorch_quantization_failure

    @classmethod
    def from_float(cls, float_instance: Bottleneck, **kwargs):
        float_instance.shortcut.add_module("residual_quantizer", quant_nn.TensorQuantizer(kwargs.get("quant_desc_input")))
        return float_instance
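Because the decorator runs at import time and updates SelectiveQuantizer's class-level mapping_instructions, importing this module is enough for a plain SelectiveQuantizer to pick up the Bottleneck mapping; the explicit custom_mappings dict in the resnet_qat example above is the equivalent, more explicit route. A hedged sketch of the implicit route (the resnet50 getter mirrors the example script; everything else here is illustrative, not part of the diff):

# Illustrative sketch: importing the module applies @register_quantized_module,
# which adds the Bottleneck -> QuantBottleneck mapping to SelectiveQuantizer's defaults.
from super_gradients.modules.quantization import resnet_bottleneck  # noqa: F401  (imported for its registration side effect)
from super_gradients.training import models as sg_models
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

model = sg_models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)

q_util = SelectiveQuantizer(default_quant_modules_calib_method="max", default_per_channel_quant_modules=True)
q_util.quantize_module(model)

print(model)  # Bottleneck blocks should now carry a residual_quantizer on their shortcut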
src/super_gradients/training/sg_trainer/sg_trainer.py

@@ -36,7 +36,6 @@ from super_gradients.training import utils as core_utils, models, dataloaders
 from super_gradients.training.models import SgModule
 from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
 from super_gradients.training.utils import sg_trainer_utils
-from super_gradients.training.utils.quantization_utils import QATCallback
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, parse_args, log_main_training_params
 from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat, GPUModeNotSetupError
 from super_gradients.training.losses import LOSSES
@@ -1045,10 +1044,9 @@ class Trainer:
         # ADD CALLBACK FOR QAT
         self.enable_qat = core_utils.get_param(self.training_params, "enable_qat", False)
         if self.enable_qat:
-            self.qat_params = core_utils.get_param(self.training_params, "qat_params")
-            if self.qat_params is None:
-                raise ValueError("Must pass QAT params when enable_qat=True")
-            self.phase_callbacks.append(QATCallback(**self.qat_params))
+            raise NotImplementedError(
+                "QAT is not implemented as a plug-and-play feature yet. Please refer to examples/resnet_qat to learn how to do it manually."
+            )

         self.phase_callback_handler = CallbackHandler(callbacks=self.phase_callbacks)
src/super_gradients/training/utils/quantization/__init__.py

from super_gradients.common.abstractions.abstract_logger import get_logger

logger = get_logger(__name__)

try:
    from super_gradients.training.utils.quantization.core import _inject_class_methods_to_default_quant_types

    _inject_class_methods_to_default_quant_types()
except (ImportError, NameError, ModuleNotFoundError):
    logger.warning("Failed to import pytorch_quantization")
src/super_gradients/training/utils/quantization/calibrator.py

"""
Quantization utilities

Methods are based on:
https://github.com/NVIDIA/TensorRT/blob/51a4297753d3e12d0eed864be52400f429a6a94c/tools/pytorch-quantization/examples/torchvision/classification_flow.py#L385

(Licensed under the Apache License, Version 2.0)
"""
import torch
from tqdm import tqdm

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
from torch.distributed import all_gather

logger = get_logger(__name__)

try:
    from pytorch_quantization import nn as quant_nn
    from pytorch_quantization import calib

    _imported_pytorch_quantization_failure = None
except (ImportError, NameError, ModuleNotFoundError) as import_err:
    logger.warning("Failed to import pytorch_quantization")
    _imported_pytorch_quantization_failure = import_err


class QuantizationCalibrator:
    def __init__(self, verbose: bool = True) -> None:
        if _imported_pytorch_quantization_failure is not None:
            raise _imported_pytorch_quantization_failure
        super().__init__()
        self.verbose = verbose

    def calibrate_model(
        self,
        model: torch.nn.Module,
        calib_data_loader: torch.utils.data.DataLoader,
        method: str = "percentile",
        num_calib_batches: int = 2,
        percentile: float = 99.99,
    ):
        """
        Calibrates torch model with quantized modules.

        :param model:               torch.nn.Module, model to perform the calibration on.
        :param calib_data_loader:   torch.utils.data.DataLoader, data loader of the calibration dataset.
                                    Assumes that the first element of the tuple is the input image.
        :param method:              str, one of [percentile, mse, entropy, max].
                                    Statistics method for amax computation of the quantized modules
                                    (Default=percentile).
        :param num_calib_batches:   int, number of batches to collect the statistics from.
        :param percentile:          float, percentile value to use when method='percentile'.
                                    Discarded when other methods are used (Default=99.99).
        """
        acceptable_methods = ["percentile", "mse", "entropy", "max"]
        if method in acceptable_methods:
            with torch.no_grad():
                self._collect_stats(model, calib_data_loader, num_batches=num_calib_batches)

                # FOR PERCENTILE WE MUST PASS PERCENTILE VALUE THROUGH KWARGS,
                # SO IT WOULD BE PASSED TO module.load_calib_amax(**kwargs), AND IN OTHER METHODS WE MUST NOT PASS IT.
                if method == "percentile":
                    self._compute_amax(model, method="percentile", percentile=percentile)
                else:
                    self._compute_amax(model, method=method)
        else:
            raise ValueError(f"Unsupported quantization calibration method, expected one of: {', '.join(acceptable_methods)}, however, received: {method}")

    def _collect_stats(self, model, data_loader, num_batches):
        """Feed data to the network and collect statistics"""
        local_rank = get_local_rank()
        world_size = get_world_size()
        device = next(model.parameters()).device

        # Enable calibrators
        self._enable_calibrators(model)

        # Feed data to the network for collecting stats
        for i, (image, *_) in tqdm(enumerate(data_loader), total=num_batches, disable=local_rank > 0):
            if world_size > 1:
                all_batches = [torch.zeros_like(image, device=device) for _ in range(world_size)]
                all_gather(all_batches, image.to(device=device))
            else:
                all_batches = [image]

            for local_image in all_batches:
                model(local_image.to(device=device))
            if i >= num_batches:
                break

        # Disable calibrators
        self._disable_calibrators(model)

    def _disable_calibrators(self, model):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    module.enable_quant()
                    module.disable_calib()
                else:
                    module.enable()

    def _enable_calibrators(self, model):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    module.disable_quant()
                    module.enable_calib()
                else:
                    module.disable()

    def _compute_amax(self, model, **kwargs):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    if isinstance(module._calibrator, calib.MaxCalibrator):
                        module.load_calib_amax()
                    else:
                        module.load_calib_amax(**kwargs)
                if self.verbose:
                    print(f"{name:40}: {module}")
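As a usage note (not part of the diff), the calibrator plugs in right after SelectiveQuantizer.quantize_module. A minimal PTQ sketch on a toy model with random calibration data; the model, shapes and loader below are illustrative assumptions:

# Minimal PTQ sketch (illustrative): quantize a toy model, then calibrate it with random data.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(8, 10)

    def forward(self, x):
        return self.head(self.pool(self.conv(x)).flatten(1))


model = TinyModel()

# Replace supported float modules with their pytorch-quantization counterparts
q_util = SelectiveQuantizer(default_quant_modules_calib_method="percentile", default_per_channel_quant_modules=False)
q_util.quantize_module(model)

# Random calibration batches; calibrate_model only reads the first element of each batch tuple
calib_loader = DataLoader(TensorDataset(torch.rand(64, 3, 32, 32)), batch_size=16)
QuantizationCalibrator(verbose=False).calibrate_model(
    model, calib_data_loader=calib_loader, method="percentile", num_calib_batches=2, percentile=99.99
)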
src/super_gradients/training/utils/quantization/core.py

import inspect
from dataclasses import dataclass
from enum import Enum
from typing import Union, Type, Optional, Set

from pytorch_quantization.nn.modules._utils import QuantMixin, QuantInputMixin
from pytorch_quantization.tensor_quant import QuantDescriptor
from torch import nn


def _extract_init_args(cls, float_instance, ignore_init_args: Set[str] = ()):
    """
    Inspecting the __init__ args, and searching for corresponding properties from the float instance,
    e.g., for `__init__(self, a)` the mechanism will look for `float_instance.a` and pass that value to `__init__`
    """
    required_init_params = list(inspect.signature(cls.__init__).parameters)[1:]  # [0] is self
    if "kwargs" in required_init_params:  # we don't want to search for a state named `kwargs`
        required_init_params.pop(required_init_params.index("kwargs"))

    float_instance_state = {}
    for p in required_init_params:
        if p in ignore_init_args:  # ignore these args and don't pick state from the instance
            continue
        if not hasattr(float_instance, p):
            raise ValueError(
                f"{float_instance.__class__.__name__} is missing `{p}` which is required "
                f"in {cls.__name__}.__init__. Either override `SGQuantBase.from_float` "
                f"or add {p} as state for {float_instance.__class__.__name__}."
            )
        float_instance_state[p] = getattr(float_instance, p)

    # Edge-cases
    if "bias" in float_instance_state:
        if float_instance_state["bias"] is None:  # None is the state when bias=False in torch.nn
            float_instance_state["bias"] = False
        elif not isinstance(float_instance_state["bias"], bool):  # Tensor is the state when bias=True in torch.nn
            float_instance_state["bias"] = True
        # in case bias is a boolean - we don't do anything, so it is taken as-is, either True or False

    return float_instance_state


def _from_float(cls, float_instance, ignore_init_args: Set[str] = (), **kwargs):
    init_params = _extract_init_args(cls, float_instance, ignore_init_args)
    init_params.update(**kwargs)
    return cls(**init_params)


class SGQuantMixin(nn.Module):
    """
    A base class for user custom Quantized classes.
    Every Quantized class must inherit this mixin, which adds the `from_float` class-method.

    NOTES:
        * the Quantized class may also inherit from the native `QuantMixin` or `QuantInputMixin`
        * quant descriptors (for inputs and weights) will be passed as `kwargs`. The module may ignore them if they are
          not necessary
        * the default implementation of `from_float` is inspecting the __init__ args, and searching for corresponding
          properties from the float instance that is passed as argument, e.g., for `__init__(self, a)`
          the mechanism will look for `float_instance.a` and pass that value to the `__init__` method
    """

    @classmethod
    def from_float(cls, float_instance, **kwargs):
        required_init_params = list(inspect.signature(cls.__init__).parameters)[1:]  # [0] is self

        # if cls.__init__ has explicit `quant_desc_input` or `quant_desc_weight` - we don't search the state of the
        # float module, because it would not contain this state. these values are injected by the framework
        ignore_init_args = {"quant_desc_input", "quant_desc_weight"}.intersection(set(required_init_params))

        # if cls.__init__ has neither **kwargs nor `quant_desc_input` and `quant_desc_weight`,
        # we should also remove these keys from the passed kwargs and make sure there's nothing more!
        if "kwargs" not in required_init_params:
            for arg in ("quant_desc_input", "quant_desc_weight"):
                if arg in ignore_init_args:
                    continue
                kwargs.pop(arg, None)  # we ignore if not existing

        return _from_float(cls, float_instance, ignore_init_args, **kwargs)


class SkipQuantization(nn.Module):
    """
    This class wraps a float module instance, and defines that this instance will not be converted to a quantized version

    Example:
        self.my_block = SkipQuantization(MyBlock(4, n_classes))
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.float_module = module
        self.forward = module.forward


@dataclass(init=True)
class QuantizedMetadata:
    """
    This dataclass is responsible for holding the information regarding the float->quantized module relation.
    It can be both layer-grained and module-grained, e.g.,
    `module.backbone.conv1 -> QuantConv2d`, `nn.Linear -> QuantLinear`, etc...

    Args:
        float_source:               the name of a specific layer (e.g., `module.backbone.conv1`),
                                    or a specific type (e.g., `Conv2d`) that will be later quantized
        quantized_target_class:     the quantized type that the source will be converted to
        action:                     how to resolve the conversion, we either:
                                    - SKIP: skip it,
                                    - UNWRAP: unwrap the instance and work with the wrapped one
                                      (i.e., we wrap with a mapper),
                                    - REPLACE: replace source with an instance of the quantized type
                                    - REPLACE_AND_RECURE: replace source with an instance of the quantized type,
                                      then try to recursively quantize the child modules of that type
                                    - RECURE_AND_REPLACE: recursively quantize the child modules, then
                                      replace source with an instance of the quantized type
        input_quant_descriptor:     quantization descriptor for inputs (None will take the default one)
        weights_quant_descriptor:   quantization descriptor for weights (None will take the default one)
    """

    class ReplacementAction(Enum):
        REPLACE = "replace"
        REPLACE_AND_RECURE = "replace_and_recure"
        RECURE_AND_REPLACE = "recure_and_replace"
        UNWRAP = "unwrap"
        SKIP = "skip"

    float_source: Union[str, Type]
    quantized_target_class: Optional[Union[Type[QuantMixin], Type[QuantInputMixin], Type[SGQuantMixin]]]
    action: ReplacementAction
    input_quant_descriptor: QuantDescriptor = None  # default is used if None
    weights_quant_descriptor: QuantDescriptor = None  # default is used if None

    def __post_init__(self):
        if self.action in (
            QuantizedMetadata.ReplacementAction.REPLACE,
            QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE,
            QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
        ):
            assert issubclass(self.quantized_target_class, (SGQuantMixin, QuantMixin, QuantInputMixin))


class QuantizedMapping(nn.Module):
    """
    This class wraps a float module instance, and defines a mapping from this instance to the corresponding quantized
    class, with the relevant quant descriptors.

    Example:
        self.my_block = QuantizedMapping(float_module=MyBlock(4, n_classes), quantized_target_class=MyQuantizedBlock)
    """

    def __init__(
        self,
        *,
        float_module: nn.Module,
        quantized_target_class: Union[Type[QuantMixin], Type[QuantInputMixin], Type[SGQuantMixin]],
        action=QuantizedMetadata.ReplacementAction.REPLACE,
        input_quant_descriptor: QuantDescriptor = None,
        weights_quant_descriptor: QuantDescriptor = None,
    ) -> None:
        super().__init__()
        self.float_module = float_module
        self.quantized_target_class = quantized_target_class
        self.action = action
        self.input_quant_descriptor = input_quant_descriptor
        self.weights_quant_descriptor = weights_quant_descriptor
        self.forward = float_module.forward


def _inject_class_methods_to_default_quant_types():
    """
    This is used to add the `from_float` capability to the "native" pytorch-quantization (=nvidia-tensorrt) quant classes,
    so that SG can support these modules out of the box
    """
    import pytorch_quantization.quant_modules

    for quant_entry in pytorch_quantization.quant_modules._DEFAULT_QUANT_MAP:
        quant_cls = quant_entry.replace_mod
        quant_cls.from_float = classmethod(_from_float)
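Beyond the decorator and the name-based registration shown in the example files, QuantizedMapping can also be used inline when building a model, as its docstring above suggests. A minimal sketch (illustrative, not part of the diff; MyBlock/MyQuantizedBlock stand in for any float/quantized pair):

# Illustrative sketch: wrap a float block with QuantizedMapping inside the model itself,
# so SelectiveQuantizer knows how to convert that specific instance.
import torch
from pytorch_quantization import nn as quant_nn
from torch import nn

from super_gradients.training.utils.quantization.core import QuantizedMapping, SGQuantMixin
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer


class MyBlock(nn.Module):
    def __init__(self, in_feats, out_feats) -> None:
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, x):
        return self.linear(x)


class MyQuantizedBlock(SGQuantMixin):
    def __init__(self, in_feats, out_feats) -> None:
        super().__init__()
        self.linear = quant_nn.QuantLinear(in_feats, out_feats)

    def forward(self, x):
        return self.linear(x)


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # The mapping travels with the instance; no global registration needed
        self.my_block = QuantizedMapping(float_module=MyBlock(32, 10), quantized_target_class=MyQuantizedBlock)

    def forward(self, x):
        return self.my_block(x)


model = MyModel()
SelectiveQuantizer().quantize_module(model)
print(model)  # my_block is now a MyQuantizedBlock instance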
src/super_gradients/training/utils/quantization/export.py

import torch

from super_gradients.common.abstractions.abstract_logger import get_logger

logger = get_logger(__name__)

try:
    from pytorch_quantization import nn as quant_nn

    _imported_pytorch_quantization_failure = None
except (ImportError, NameError, ModuleNotFoundError) as import_err:
    logger.warning("Failed to import pytorch_quantization")
    _imported_pytorch_quantization_failure = import_err


def export_quantized_module_to_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tuple, **kwargs):
    """
    Method for exporting onnx after QAT.

    :param model:           torch.nn.Module, model to export
    :param onnx_filename:   str, target path for the onnx file
    :param input_shape:     tuple, input shape (usually BCHW)
    """
    if _imported_pytorch_quantization_failure is not None:
        raise _imported_pytorch_quantization_failure

    model.eval()
    if hasattr(model, "prep_model_for_conversion"):
        model.prep_model_for_conversion(**kwargs)

    use_fb_fake_quant_state = quant_nn.TensorQuantizer.use_fb_fake_quant
    quant_nn.TensorQuantizer.use_fb_fake_quant = True

    # Export ONNX for multiple batch sizes
    logger.info("Creating ONNX file: " + onnx_filename)
    dummy_input = torch.randn(input_shape, device=next(model.parameters()).device)
    torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=13, enable_onnx_checker=False, do_constant_folding=True)

    # Restore functions of quant_nn back as expected
    quant_nn.TensorQuantizer.use_fb_fake_quant = use_fb_fake_quant_state
src/super_gradients/training/utils/quantization/selective_quantization_utils.py

from typing import Tuple, Set, Type, Dict, Union, Callable, Optional

from torch import nn

from super_gradients.common.abstractions.abstract_logger import get_logger

logger = get_logger(__name__)

try:
    from pytorch_quantization.nn.modules._utils import QuantMixin, QuantInputMixin
    from pytorch_quantization.tensor_quant import QuantDescriptor
    from pytorch_quantization import nn as quant_nn

    from super_gradients.training.utils.quantization.core import SkipQuantization, SGQuantMixin, QuantizedMapping, QuantizedMetadata

    _imported_pytorch_quantization_failure = None
except (ImportError, NameError, ModuleNotFoundError) as import_err:
    logger.warning("Failed to import pytorch_quantization")
    _imported_pytorch_quantization_failure = import_err


def register_quantized_module(
    float_source: Union[str, Type[nn.Module]],
    action: QuantizedMetadata.ReplacementAction = QuantizedMetadata.ReplacementAction.REPLACE,
    input_quant_descriptor: Optional[QuantDescriptor] = None,
    weights_quant_descriptor: Optional[QuantDescriptor] = None,
) -> Callable:
    """
    Decorator used to register a Quantized module as the quantized version of a Float module.

    :param action:                      action to perform on the float_source
    :param float_source:                the float module type that is being registered
    :param input_quant_descriptor:      the input quantization descriptor
    :param weights_quant_descriptor:    the weight quantization descriptor
    """

    def decorator(quant_module: Type[SGQuantMixin]) -> Type[SGQuantMixin]:
        if float_source in SelectiveQuantizer.mapping_instructions:
            metadata = SelectiveQuantizer.mapping_instructions[float_source]
            raise ValueError(f"`{float_source}` is already registered with the following metadata {metadata}")

        SelectiveQuantizer.mapping_instructions.update(
            {
                float_source: QuantizedMetadata(
                    float_source=float_source,
                    quantized_target_class=quant_module,
                    input_quant_descriptor=input_quant_descriptor,
                    weights_quant_descriptor=weights_quant_descriptor,
                    action=action,
                )
            }
        )
        return quant_module  # this is required since the decorator assigns the result to the `quant_module`

    return decorator


class SelectiveQuantizer:
    """
    :param custom_mappings:                         custom mappings that extend the default mappings with extra behaviour
    :param default_quant_modules_calib_method:      default calibration method (default='percentile')
    :param default_per_channel_quant_modules:       whether quant modules should be per channel (default=False)
    """

    if _imported_pytorch_quantization_failure is not None:
        raise _imported_pytorch_quantization_failure

    mapping_instructions: Dict[Union[str, Type], QuantizedMetadata] = {
        **{
            float_type: QuantizedMetadata(
                float_source=float_type,
                quantized_target_class=quantized_target_class,
                action=QuantizedMetadata.ReplacementAction.REPLACE,
            )
            for (float_type, quantized_target_class) in [
                (nn.Conv1d, quant_nn.QuantConv1d),
                (nn.Conv2d, quant_nn.QuantConv2d),
                (nn.Conv3d, quant_nn.QuantConv3d),
                (nn.ConvTranspose1d, quant_nn.QuantConvTranspose1d),
                (nn.ConvTranspose2d, quant_nn.QuantConvTranspose2d),
                (nn.ConvTranspose3d, quant_nn.QuantConvTranspose3d),
                (nn.Linear, quant_nn.Linear),
                (nn.LSTM, quant_nn.LSTM),
                (nn.LSTMCell, quant_nn.LSTMCell),
                (nn.AvgPool1d, quant_nn.QuantAvgPool1d),
                (nn.AvgPool2d, quant_nn.QuantAvgPool2d),
                (nn.AvgPool3d, quant_nn.QuantAvgPool3d),
                (nn.AdaptiveAvgPool1d, quant_nn.QuantAdaptiveAvgPool1d),
                (nn.AdaptiveAvgPool2d, quant_nn.QuantAdaptiveAvgPool2d),
                (nn.AdaptiveAvgPool3d, quant_nn.QuantAdaptiveAvgPool3d),
            ]
        },
        SkipQuantization: QuantizedMetadata(float_source=SkipQuantization, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.UNWRAP),
    }  # DEFAULT MAPPING INSTRUCTIONS

    def __init__(
        self, *, custom_mappings: dict = None, default_quant_modules_calib_method: str = "max", default_per_channel_quant_modules: bool = True
    ) -> None:
        super().__init__()
        self.default_quant_modules_calib_method = default_quant_modules_calib_method
        self.default_per_channel_quant_modules = default_per_channel_quant_modules
        self.mapping_instructions = self.mapping_instructions.copy()
        if custom_mappings is not None:
            self.mapping_instructions.update(custom_mappings)  # OVERRIDE DEFAULT WITH CUSTOM. CUSTOM IS PRIORITIZED

    def _get_default_quant_descriptor(self, for_weights=False):
        if self.default_quant_modules_calib_method in ["percentile", "mse", "entropy"]:
            calib_method_type = "histogram"
        else:
            calib_method_type = "max"

        if self.default_per_channel_quant_modules and for_weights:
            return QuantDescriptor(calib_method=calib_method_type, axis=0)
        return QuantDescriptor(calib_method=calib_method_type)

    def register_skip_quantization(self, *, layer_names: Set[str]):
        self.mapping_instructions.update(
            {name: QuantizedMetadata(float_source=name, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.SKIP) for name in layer_names}
        )

    def register_quantization_mapping(
        self, *, layer_names: Set[str], quantized_target_class: Type[SGQuantMixin], input_quant_descriptor=None, weights_quant_descriptor=None
    ):
        self.mapping_instructions.update(
            {
                name: QuantizedMetadata(
                    float_source=name,
                    quantized_target_class=quantized_target_class,
                    action=QuantizedMetadata.ReplacementAction.REPLACE,
                    input_quant_descriptor=input_quant_descriptor,
                    weights_quant_descriptor=weights_quant_descriptor,
                )
                for name in layer_names
            }
        )

    def _preprocess_skips_and_custom_mappings(self, module: nn.Module, nesting: Tuple[str, ...] = ()):
        """
        This pass is done to extract layer names and mapping instructions, so that we can do per-layer processing.
        Relevant layer-specific mapping instructions are either `SkipQuantization` or `QuantizedMapping`, which are then
        added to the mappings
        """
        mapping_instructions = dict()
        for name, child_module in module.named_children():
            nested_name = ".".join(nesting + (name,))
            if isinstance(child_module, SkipQuantization):
                mapping_instructions[nested_name] = QuantizedMetadata(
                    float_source=nested_name, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.UNWRAP
                )

            if isinstance(child_module, QuantizedMapping):
                mapping_instructions[nested_name] = QuantizedMetadata(
                    float_source=nested_name,
                    quantized_target_class=child_module.quantized_target_class,
                    input_quant_descriptor=child_module.input_quant_descriptor,
                    weights_quant_descriptor=child_module.weights_quant_descriptor,
                    action=child_module.action,
                )

            if isinstance(child_module, nn.Module):  # recursive call
                mapping_instructions.update(self._preprocess_skips_and_custom_mappings(child_module, nesting + (name,)))

        return mapping_instructions

    def _instantiate_quantized_from_float(self, float_module, metadata, preserve_state_dict):
        base_classes = (QuantMixin, QuantInputMixin, SGQuantMixin)
        if not issubclass(metadata.quantized_target_class, base_classes):
            raise AssertionError(
                f"Quantization suite for {type(float_module).__name__} is invalid. "
                f"{metadata.quantized_target_class.__name__} must inherit one of "
                f"{', '.join(map(lambda _: _.__name__, base_classes))}"
            )

        # USE PROVIDED QUANT DESCRIPTORS, OR DEFAULT IF NONE PROVIDED
        quant_descriptors = dict()
        if issubclass(metadata.quantized_target_class, (SGQuantMixin, QuantMixin, QuantInputMixin)):
            quant_descriptors = {"quant_desc_input": metadata.input_quant_descriptor or self._get_default_quant_descriptor(for_weights=False)}
        if issubclass(metadata.quantized_target_class, (SGQuantMixin, QuantMixin)):
            quant_descriptors.update({"quant_desc_weight": metadata.weights_quant_descriptor or self._get_default_quant_descriptor(for_weights=True)})

        if not hasattr(metadata.quantized_target_class, "from_float"):
            assert isinstance(metadata.quantized_target_class, SGQuantMixin), (
                f"{metadata.quantized_target_class.__name__} must inherit from "
                f"{SGQuantMixin.__name__}, so that it would include the `from_float` class method"
            )

        q_instance = metadata.quantized_target_class.from_float(float_module, **quant_descriptors)

        # MOVE TENSORS TO ORIGINAL DEVICE
        if len(list(float_module.parameters(recurse=False))) > 0:
            q_instance = q_instance.to(next(float_module.parameters(recurse=False)).device)
        elif len(list(float_module.buffers(recurse=False))):
            q_instance = q_instance.to(next(float_module.buffers(recurse=False)).device)

        # COPY STATE DICT IF NEEDED
        if preserve_state_dict:
            q_instance.load_state_dict(float_module.state_dict(), strict=True)

        return q_instance

    def _maybe_quantize_one_layer(
        self,
        module: nn.Module,
        child_name: str,
        nesting: Tuple[str, ...],
        child_module: nn.Module,
        mapping_instructions: Dict[Union[str, Type], QuantizedMetadata],
        preserve_state_dict: bool,
    ) -> bool:
        """
        Does the heavy lifting of (maybe) quantizing a layer: creates a quantized instance based on a float instance,
        and replaces it in the "parent" module

        :param module:                  the module we'd like to quantize a specific layer in
        :param child_name:              the attribute name of the layer in the module
        :param nesting:                 the current nesting we're in. Needed to find the appropriate key in the mappings
        :param child_module:            the instance of the float module we'd like to quantize
        :param mapping_instructions:    mapping instructions: how to quantize
        :param preserve_state_dict:     whether to copy the state dict from the float instance to the quantized instance

        :return: a boolean that indicates whether we found a match and should not continue recursively
        """
        # if we don't have any instruction for the specific layer or the specific type - we continue
        # NOTE! IT IS IMPORTANT TO FIRST PROCESS THE NAME AND ONLY THEN THE TYPE
        if _imported_pytorch_quantization_failure is not None:
            raise _imported_pytorch_quantization_failure

        for candidate_key in (".".join(nesting + (child_name,)), type(child_module)):
            if candidate_key not in mapping_instructions:
                continue
            metadata: QuantizedMetadata = mapping_instructions[candidate_key]
            if metadata.action == QuantizedMetadata.ReplacementAction.SKIP:
                return True
            elif metadata.action == QuantizedMetadata.ReplacementAction.UNWRAP:
                assert isinstance(child_module, SkipQuantization)
                setattr(module, child_name, child_module.float_module)
                return True
            elif metadata.action in (
                QuantizedMetadata.ReplacementAction.REPLACE,
                QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE,
                QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
            ):
                if isinstance(child_module, QuantizedMapping):  # UNWRAP MAPPING
                    child_module = child_module.float_module
                q_instance: nn.Module = self._instantiate_quantized_from_float(
                    float_module=child_module, metadata=metadata, preserve_state_dict=preserve_state_dict
                )

                # ACTUAL REPLACEMENT
                def replace():
                    setattr(module, child_name, q_instance)

                def recurse_quantize():
                    self._quantize_module_aux(
                        module=getattr(module, child_name),
                        mapping_instructions=mapping_instructions,
                        nesting=nesting + (child_name,),
                        preserve_state_dict=preserve_state_dict,
                    )

                if metadata.action == QuantizedMetadata.ReplacementAction.REPLACE:
                    replace()
                elif metadata.action == QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE:
                    replace()
                    recurse_quantize()
                elif metadata.action == QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE:
                    recurse_quantize()
                    replace()
                return True
            else:
                raise NotImplementedError
        return False

    def quantize_module(self, module: nn.Module, *, preserve_state_dict=True):
        per_layer_mappings = self._preprocess_skips_and_custom_mappings(module)
        mapping_instructions = {
            **per_layer_mappings,
            **self.mapping_instructions,
        }  # we first regard the per-layer mappings, and then override with the custom mappings in case there is overlap
        self._quantize_module_aux(mapping_instructions=mapping_instructions, module=module, nesting=(), preserve_state_dict=preserve_state_dict)

    def _quantize_module_aux(self, mapping_instructions, module, nesting, preserve_state_dict):
        for name, child_module in module.named_children():
            found = self._maybe_quantize_one_layer(module, name, nesting, child_module, mapping_instructions, preserve_state_dict)

            # RECURSIVE CALL, to support module_list, sequential, custom (nested) modules
            if not found and isinstance(child_module, nn.Module):
                self._quantize_module_aux(mapping_instructions, child_module, nesting + (name,), preserve_state_dict)
  1. """
  2. Quantization utilities
  3. Methods are based on:
  4. https://github.com/NVIDIA/TensorRT/blob/51a4297753d3e12d0eed864be52400f429a6a94c/tools/pytorch-quantization/examples/torchvision/classification_flow.py#L385
  5. (Licensed under the Apache License, Version 2.0)
  6. """
  7. from torch.utils.data import DataLoader
  8. from tqdm import tqdm
  9. import torch
  10. from super_gradients.common.abstractions.abstract_logger import get_logger
  11. from super_gradients.training import models
  12. from super_gradients.training.utils.callbacks import Phase, PhaseCallback, PhaseContext
  13. import os
  14. from enum import Enum
  15. from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model
  16. from super_gradients.training.utils import get_param
  17. from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
  18. from torch.distributed import all_gather
  19. logger = get_logger(__name__)
  20. try:
  21. from pytorch_quantization import nn as quant_nn, quant_modules
  22. from pytorch_quantization import calib
  23. from pytorch_quantization.tensor_quant import QuantDescriptor
  24. _imported_pytorch_quantization_failure = None
  25. except (ImportError, NameError, ModuleNotFoundError) as import_err:
  26. logger.debug("Failed to import pytorch_quantization")
  27. _imported_pytorch_quantization_failure = import_err
  28. class QuantizationLevel(str, Enum):
  29. FP32 = "FP32"
  30. FP16 = "FP16"
  31. INT8 = "INT8"
  32. HYBRID = "Hybrid"
  33. @staticmethod
  34. def from_string(quantization_level: str) -> Enum:
  35. quantization_level = quantization_level.lower()
  36. if quantization_level == "fp32":
  37. return QuantizationLevel.FP32
  38. elif quantization_level == "fp16":
  39. return QuantizationLevel.FP16
  40. elif quantization_level == "int8":
  41. return QuantizationLevel.INT8
  42. elif quantization_level == "hybrid":
  43. return QuantizationLevel.HYBRID
  44. else:
  45. raise NotImplementedError(f'Quantization Level: "{quantization_level}" is not supported')
  46. def export_qat_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tuple, per_channel_quantization: bool = False):
  47. """
  48. Method for exporting ONNX after QAT.
  49. :param model: torch.nn.Module, model to export
  50. :param onnx_filename: str, target path for the ONNX file
  51. :param input_shape: tuple, input shape (usually BCHW); per_channel_quantization=True exports with ONNX opset 13, otherwise 12
  52. """
  53. if _imported_pytorch_quantization_failure is not None:
  54. raise _imported_pytorch_quantization_failure
  55. else:
  56. model.eval()
  57. if hasattr(model, "prep_model_for_conversion"):
  58. model.prep_model_for_conversion()
  59. quant_nn.TensorQuantizer.use_fb_fake_quant = True
  60. # Export ONNX for multiple batch sizes
  61. logger.info("Creating ONNX file: " + onnx_filename)
  62. dummy_input = torch.randn(input_shape, device="cuda")
  63. opset_version = 13 if per_channel_quantization else 12
  64. torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=opset_version, enable_onnx_checker=False, do_constant_folding=True)
  65. def calibrate_model(
  66. model: torch.nn.Module, calib_data_loader: torch.utils.data.DataLoader, method: str = "percentile", num_calib_batches: int = 2, percentile: float = 99.99
  67. ):
  68. """
  69. Calibrates torch model with quantized modules.
  70. :param model: torch.nn.Module, model to perform the calibration on.
  71. :param calib_data_loader: torch.utils.data.DataLoader, data loader of the calibration dataset.
  72. :param method: str, One of [percentile, mse, entropy, max]. Statistics method for amax computation of the quantized modules
  73. (Default=percentile).
  74. :param num_calib_batches: int, number of batches to collect the statistics from.
  75. :param percentile: float, percentile value to use when quant_modules_calib_method='percentile'. Discarded when other methods are used
  76. (Default=99.99).
  77. """
  78. if _imported_pytorch_quantization_failure is not None:
  79. raise _imported_pytorch_quantization_failure
  80. elif method in ["percentile", "mse", "entropy", "max"]:
  81. with torch.no_grad():
  82. _collect_stats(model, calib_data_loader, num_batches=num_calib_batches)
  83. # FOR PERCENTILE WE MUST PASS PERCENTILE VALUE THROUGH KWARGS,
  84. # SO IT WOULD BE PASSED TO module.load_calib_amax(**kwargs), AND IN OTHER METHODS WE MUST NOT PASS IT.
  85. if method == "precentile":
  86. _compute_amax(model, method="percentile", percentile=percentile)
  87. else:
  88. _compute_amax(model, method=method)
  89. else:
  90. raise ValueError("Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(method) + ".")
  91. def _collect_stats(model, data_loader, num_batches):
  92. """Feed data to the network and collect statistics"""
  93. if _imported_pytorch_quantization_failure is not None:
  94. raise _imported_pytorch_quantization_failure
  95. else:
  96. local_rank = get_local_rank()
  97. world_size = get_world_size()
  98. # Enable calibrators
  99. _enable_calibrators(model)
  100. # Feed data to the network for collecting stats
  101. for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches, disable=local_rank > 0):
  102. if world_size > 1:
  103. all_batches = [torch.zeros_like(image, device="cuda") for _ in range(world_size)]
  104. all_gather(all_batches, image.cuda())
  105. else:
  106. all_batches = [image]
  107. for local_image in all_batches:
  108. model(local_image.cuda())
  109. if i >= num_batches:
  110. break
  111. # Disable calibrators
  112. _disable_calibrators(model)
  113. def _disable_calibrators(model):
  114. for name, module in model.named_modules():
  115. if isinstance(module, quant_nn.TensorQuantizer):
  116. if module._calibrator is not None:
  117. module.enable_quant()
  118. module.disable_calib()
  119. else:
  120. module.enable()
  121. def _enable_calibrators(model):
  122. for name, module in model.named_modules():
  123. if isinstance(module, quant_nn.TensorQuantizer):
  124. if module._calibrator is not None:
  125. module.disable_quant()
  126. module.enable_calib()
  127. else:
  128. module.disable()
  129. def _compute_amax(model, **kwargs):
  130. if _imported_pytorch_quantization_failure is not None:
  131. raise _imported_pytorch_quantization_failure
  132. else:
  133. # Load calib result
  134. for name, module in model.named_modules():
  135. if isinstance(module, quant_nn.TensorQuantizer):
  136. if module._calibrator is not None:
  137. if isinstance(module._calibrator, calib.MaxCalibrator):
  138. module.load_calib_amax()
  139. else:
  140. module.load_calib_amax(**kwargs)
  141. model.cuda()
  142. def _deactivate_quant_modules_wrapping():
  143. """
  144. Deactivates quant modules wrapping, so that further modules won't use Q/DQ layers.
  145. """
  146. if _imported_pytorch_quantization_failure is not None:
  147. raise _imported_pytorch_quantization_failure
  148. else:
  149. quant_modules.deactivate()
  150. def _activate_quant_modules_wrapping():
  151. """
  152. Activates quant modules wrapping, so that further modules use Q/DQ layers.
  153. """
  154. if _imported_pytorch_quantization_failure is not None:
  155. raise _imported_pytorch_quantization_failure
  156. else:
  157. quant_modules.initialize()
  158. class QATCallback(PhaseCallback):
  159. """
  160. A callback for transitioning training into QAT.
  161. Rebuilds the model with QAT layers then either:
  162. 1. loads the best checkpoint then performs calibration.
  163. 2. loads an external calibrated model (makes sense when start_epoch=0).
  164. Additionally, resets Trainer's best_metric and sets ckpt_best_name to 'qat_ckpt_best.pth' so best QAT checkpoints
  165. will be saved separately.
  166. If performing calibration- the calibrated model is evaluated, and the metric_to_watch is logged under
  167. calibrated_model_{metric_to_watch}. The calibrated checkpoint is saved under ckpt_calibrated_{calibration_method}.pth
  168. Attributes:
  169. start_epoch: int, first epoch to start QAT.
  170. quant_modules_calib_method: str, One of [percentile, mse, entropy, max]. Statistics method for amax
  171. computation of the quantized modules (default=percentile).
  172. per_channel_quant_modules: bool, whether quant modules should be per channel (default=False).
  173. calibrate: bool, whether to perform calibration (default=True).
  174. calibrated_model_path: str, path to a calibrated checkpoint (default=None).
  175. calib_data_loader: torch.utils.data.DataLoader, data loader of the calibration dataset. When None,
  176. context.train_loader will be used (default=None).
  177. num_calib_batches: int, number of batches to collect the statistics from.
  178. percentile: float, percentile value to use when quant_modules_calib_method='percentile'.
  179. Discarded when other methods are used (Default=99.99).
  180. """
  181. def __init__(
  182. self,
  183. start_epoch: int,
  184. quant_modules_calib_method: str = "percentile",
  185. per_channel_quant_modules: bool = False,
  186. calibrate: bool = True,
  187. calibrated_model_path: str = None,
  188. calib_data_loader: DataLoader = None,
  189. num_calib_batches: int = 2,
  190. percentile: float = 99.99,
  191. ):
  192. super(QATCallback, self).__init__(Phase.TRAIN_EPOCH_START)
  193. self._validate_args(start_epoch, quant_modules_calib_method, calibrate, calibrated_model_path)
  194. self.start_epoch = start_epoch
  195. self.quant_modules_calib_method = quant_modules_calib_method
  196. self.per_channel_quant_modules = per_channel_quant_modules
  197. self.calibrate = calibrate
  198. self.calibrated_model_path = calibrated_model_path
  199. self.calib_data_loader = calib_data_loader
  200. self.num_calib_batches = num_calib_batches
  201. self.percentile = percentile
  202. def _validate_args(self, start_epoch: int, quant_modules_calib_method: str, calibrate, calibrated_model_path):
  203. if _imported_pytorch_quantization_failure:
  204. raise _imported_pytorch_quantization_failure
  205. if start_epoch < 0:
  206. raise ValueError("start_epoch must be positive.")
  207. if quant_modules_calib_method not in ["percentile", "mse", "entropy", "max"]:
  208. raise ValueError(
  209. "Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(self.quant_modules_calib_method) + "."
  210. )
  211. if not calibrate and calibrated_model_path is None:
  212. logger.warning("calibrate=False and no calibrated_model_path is given. QAT will be on an uncalibrated model.")
  213. def __call__(self, context: PhaseContext):
  214. if context.epoch == self.start_epoch:
  215. # REMOVE REFERENCES TO NETWORK AND CLEAN GPU MEMORY BEFORE BUILDING THE NEW NET
  216. context.context_methods.set_net(None)
  217. context.net = None
  218. torch.cuda.empty_cache()
  219. # BUILD THE SAME MODEL BUT WITH FAKE QUANTIZED MODULES, AND LOAD BEST CHECKPOINT TO IT
  220. self._initialize_quant_modules()
  221. # PRIORITIZE AN EXTERNAL CALIBRATED CHECKPOINT; OTHERWISE FALL BACK TO THE BEST CHECKPOINT (None WHEN start_epoch == 0)
  222. checkpoint_path = self.calibrated_model_path
  223. if checkpoint_path is None and self.start_epoch > 0:
  224. checkpoint_path = os.path.join(context.ckpt_dir, "ckpt_best.pth")
  225. qat_net = models.get(context.architecture, arch_params=context.arch_params.to_dict(), checkpoint_path=checkpoint_path)
  226. _deactivate_quant_modules_wrapping()
  227. # UPDATE CONTEXT'S NET REFERENCE
  228. context.net = context.context_methods.get_net()
  229. if self.calibrate:
  230. self._calibrate_model(context)
  231. # RESET THE BEST METRIC VALUE SO WE SAVE CHECKPOINTS AFTER THE EXPECTED QAT ACCURACY DEGRADATION
  232. context.context_methods._reset_best_metric()
  233. # SET NEW FILENAME FOR THE BEST CHECKPOINT SO WE DON'T OVERRIDE THE PREVIOUS ONES
  234. context.context_methods.set_ckpt_best_name("qat_ckpt_best.pth")
  235. # FINALLY, SET THE QAT NET TO CONTINUE TRAINING
  236. context.context_methods.set_net(qat_net)
  237. def _calibrate_model(self, context: PhaseContext):
  238. """
  239. Performs model calibration (collecting stats and setting amax for the fake quantized modules).
  240. :param context: PhaseContext, current phase context.
  241. """
  242. self.calib_data_loader = self.calib_data_loader or context.train_loader
  243. calibrate_model(
  244. model=context.net,
  245. calib_data_loader=self.calib_data_loader,
  246. method=self.quant_modules_calib_method,
  247. num_calib_batches=self.num_calib_batches,
  248. percentile=self.percentile,
  249. )
  250. method_desc = (
  251. self.quant_modules_calib_method + "_" + str(self.percentile) if self.quant_modules_calib_method == "percentile" else self.quant_modules_calib_method
  252. )
  253. if not context.ddp_silent_mode:
  254. logger.info("Performing additional validation on calibrated model...")
  255. calibrated_valid_results = context.context_methods.validate_epoch(epoch=self.start_epoch, silent_mode=True)
  256. calibrated_acc = calibrated_valid_results[context.metric_idx_in_results_tuple]
  257. if not context.ddp_silent_mode:
  258. logger.info("Calibrate model " + context.metric_to_watch + ": " + str(calibrated_acc))
  259. context.sg_logger.add_checkpoint(tag="ckpt_calibrated_" + method_desc + ".pth", state_dict={"net": context.net.state_dict(), "acc": calibrated_acc})
  260. context.sg_logger.add_scalar("Calibrated_Model_" + context.metric_to_watch, calibrated_acc, global_step=self.start_epoch)
  261. def _initialize_quant_modules(self):
  262. """
  263. Initialize quant modules wrapping.
  264. """
  265. if _imported_pytorch_quantization_failure is not None:
  266. raise _imported_pytorch_quantization_failure
  267. else:
  268. if self.quant_modules_calib_method in ["percentile", "mse", "entropy"]:
  269. calib_method_type = "histogram"
  270. else:
  271. calib_method_type = "max"
  272. if self.per_channel_quant_modules:
  273. quant_desc_input = QuantDescriptor(calib_method=calib_method_type)
  274. quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
  275. quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
  276. else:
  277. quant_desc_input = QuantDescriptor(calib_method=calib_method_type, axis=None)
  278. quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
  279. quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
  280. quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
  281. quant_desc_weight = QuantDescriptor(calib_method=calib_method_type, axis=None)
  282. quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)
  283. quant_nn.QuantConvTranspose2d.set_default_quant_desc_weight(quant_desc_weight)
  284. quant_nn.QuantLinear.set_default_quant_desc_weight(quant_desc_weight)
  285. _activate_quant_modules_wrapping()
  286. class PostQATConversionCallback(PhaseCallback):
  287. """
  288. Post-QAT training callback that saves the best QAT checkpoint (i.e. qat_ckpt_best.pth) in ONNX format.
  289. Should be used with QATCallback.
  290. Attributes:
  291. dummy_input_size: (tuple) dummy input size for the ONNX conversion.
  292. """
  293. def __init__(self, dummy_input_size):
  294. super().__init__(phase=Phase.POST_TRAINING)
  295. self.dummy_input_size = dummy_input_size
  296. def __call__(self, context: PhaseContext):
  297. if not context.ddp_silent_mode:
  298. best_ckpt_path = os.path.join(context.ckpt_dir, "qat_ckpt_best.pth")
  299. onnx_path = os.path.join(context.ckpt_dir, "qat_ckpt_best.onnx")
  300. load_checkpoint_to_model(
  301. ckpt_local_path=best_ckpt_path,
  302. net=context.net,
  303. load_weights_only=True,
  304. load_ema_as_net=context.training_params.ema,
  305. strict=True,
  306. load_backbone=False,
  307. )
  308. per_channel_quant_modules = get_param(context.training_params.qat_params, "per_channel_quant_modules")
  309. export_qat_onnx(context.net.module, onnx_path, self.dummy_input_size, per_channel_quant_modules)
  310. context.sg_logger.add_file("qat_ckpt_best.onnx")
@@ -22,6 +22,7 @@ from tests.unit_tests import (
 )
 from tests.end_to_end_tests import TestTrainer
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
+from tests.unit_tests.quantization_utility_tests import QuantizationUtilityTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.train_with_intialized_param_args_test import TrainWithInitializedObjectsTest
@@ -90,6 +91,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(IoULossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubsampling))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubclassing))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(QuantizationUtilityTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetCaching))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MultiScaleTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainingParamsTest))
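To run just the newly registered quantization tests outside the full suite runner, something along these lines should work (a sketch, assuming the repository root is on PYTHONPATH; the whole test class is skipped automatically when pytorch_quantization is not installed):

import unittest

from tests.unit_tests.quantization_utility_tests import QuantizationUtilityTest

suite = unittest.TestLoader().loadTestsFromTestCase(QuantizationUtilityTest)
unittest.TextTestRunner(verbosity=2).run(suite)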
  1. import unittest
  2. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  3. from super_gradients.training import Trainer, MultiGPUMode, models
  4. from super_gradients.training.metrics.classification_metrics import Accuracy
  5. import os
  6. from super_gradients.training.utils.quantization_utils import PostQATConversionCallback
  7. class QATIntegrationTest(unittest.TestCase):
  8. def _get_trainer(self, experiment_name):
  9. trainer = Trainer(experiment_name,
  10. multi_gpu=MultiGPUMode.OFF)
  11. model = models.get("resnet18", pretrained_weights="imagenet")
  12. return trainer, model
  13. def _get_train_params(self, qat_params):
  14. train_params = {"max_epochs": 2,
  15. "lr_mode": "step",
  16. "optimizer": "SGD",
  17. "lr_updates": [],
  18. "lr_decay_factor": 0.1,
  19. "initial_lr": 0.001, "loss": "cross_entropy",
  20. "train_metrics_list": [Accuracy()],
  21. "valid_metrics_list": [Accuracy()],
  22. "metric_to_watch": "Accuracy",
  23. "greater_metric_to_watch_is_better": True,
  24. "average_best_models": False,
  25. "enable_qat": True,
  26. "qat_params": qat_params,
  27. "phase_callbacks": [PostQATConversionCallback(dummy_input_size=(1, 3, 224, 224))]
  28. }
  29. return train_params
  30. def test_qat_from_start(self):
  31. model, net = self._get_trainer("test_qat_from_start")
  32. train_params = self._get_train_params(qat_params={
  33. "start_epoch": 0,
  34. "quant_modules_calib_method": "percentile",
  35. "calibrate": True,
  36. "num_calib_batches": 2,
  37. "percentile": 99.99
  38. })
  39. model.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(),
  40. valid_loader=classification_test_dataloader())
  41. def test_qat_transition(self):
  42. model, net = self._get_trainer("test_qat_transition")
  43. train_params = self._get_train_params(qat_params={
  44. "start_epoch": 1,
  45. "quant_modules_calib_method": "percentile",
  46. "calibrate": True,
  47. "num_calib_batches": 2,
  48. "percentile": 99.99
  49. })
  50. model.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(),
  51. valid_loader=classification_test_dataloader())
  52. def test_qat_from_calibrated_ckpt(self):
  53. model, net = self._get_trainer("generate_calibrated_model")
  54. train_params = self._get_train_params(qat_params={
  55. "start_epoch": 0,
  56. "quant_modules_calib_method": "percentile",
  57. "calibrate": True,
  58. "num_calib_batches": 2,
  59. "percentile": 99.99
  60. })
  61. model.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(),
  62. valid_loader=classification_test_dataloader())
  63. calibrated_model_path = os.path.join(model.checkpoints_dir_path, "ckpt_calibrated_percentile_99.99.pth")
  64. model, net = self._get_trainer("test_qat_from_calibrated_ckpt")
  65. train_params = self._get_train_params(qat_params={
  66. "start_epoch": 0,
  67. "quant_modules_calib_method": "percentile",
  68. "calibrate": False,
  69. "calibrated_model_path": calibrated_model_path,
  70. "num_calib_batches": 2,
  71. "percentile": 99.99
  72. })
  73. model.train(model=net, training_params=train_params, train_loader=classification_test_dataloader(),
  74. valid_loader=classification_test_dataloader())
  75. if __name__ == '__main__':
  76. unittest.main()
  1. import unittest
  2. import torch
  3. import torchvision
  4. from torch import nn
  5. try:
  6. import super_gradients
  7. from pytorch_quantization import nn as quant_nn
  8. from pytorch_quantization import quant_modules
  9. from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer, register_quantized_module
  10. from pytorch_quantization.calib import MaxCalibrator, HistogramCalibrator
  11. from super_gradients.training.utils.quantization.core import SkipQuantization, SGQuantMixin, QuantizedMapping, QuantizedMetadata
  12. from pytorch_quantization.nn import QuantConv2d
  13. from pytorch_quantization.tensor_quant import QuantDescriptor
  14. _imported_pytorch_quantization_failure = False
  15. except (ImportError, NameError, ModuleNotFoundError):
  16. _imported_pytorch_quantization_failure = True
  17. @unittest.skipIf(_imported_pytorch_quantization_failure, "Failed to import `pytorch_quantization`")
  18. class QuantizationUtilityTest(unittest.TestCase):
  19. def test_vanilla_replacement(self):
  20. # ARRANGE
  21. class MyModel(nn.Module):
  22. def __init__(self) -> None:
  23. super().__init__()
  24. self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
  25. def forward(self, x):
  26. return self.conv1(x)
  27. module = MyModel()
  28. # TEST
  29. q_util = SelectiveQuantizer()
  30. q_util.quantize_module(module)
  31. x = torch.rand(1, 3, 32, 32)
  32. # ASSERT
  33. with torch.no_grad():
  34. y = module(x)
  35. torch.testing.assert_close(y.size(), (1, 8, 32, 32))
  36. self.assertTrue(isinstance(module.conv1, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  37. def test_module_list_replacement(self):
  38. # ARRANGE
  39. class MyModel(nn.Module):
  40. def __init__(self) -> None:
  41. super().__init__()
  42. self.convs = nn.ModuleList([nn.Conv2d(3, 8, kernel_size=3, padding=1) for _ in range(3)])
  43. def forward(self, x):
  44. return torch.cat([conv(x) for conv in self.convs], dim=1)
  45. module = MyModel()
  46. # TEST
  47. q_util = SelectiveQuantizer()
  48. q_util.quantize_module(module)
  49. x = torch.rand(1, 3, 32, 32)
  50. # ASSERT
  51. with torch.no_grad():
  52. y = module(x)
  53. torch.testing.assert_close(y.size(), (1, 3 * 8, 32, 32))
  54. for conv in module.convs:
  55. self.assertTrue(isinstance(conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  56. def test_sequential_list_replacement(self):
  57. # ARRANGE
  58. class MyModel(nn.Module):
  59. def __init__(self) -> None:
  60. super().__init__()
  61. self.convs = nn.Sequential(
  62. nn.Conv2d(3, 8, kernel_size=3, padding=1),
  63. nn.Conv2d(8, 16, kernel_size=3, padding=1),
  64. )
  65. def forward(self, x):
  66. return self.convs(x)
  67. module = MyModel()
  68. # TEST
  69. q_util = SelectiveQuantizer()
  70. q_util.quantize_module(module)
  71. x = torch.rand(1, 3, 32, 32)
  72. # ASSERT
  73. with torch.no_grad():
  74. y = module(x)
  75. torch.testing.assert_close(y.size(), (1, 16, 32, 32))
  76. for conv in module.convs:
  77. self.assertTrue(isinstance(conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  78. def test_nested_module_replacement(self):
  79. # ARRANGE
  80. class MyBlock(nn.Module):
  81. def __init__(self, in_feats, out_feats) -> None:
  82. super().__init__()
  83. self.flatten = nn.Flatten()
  84. self.linear = nn.Linear(in_feats, out_feats)
  85. def forward(self, x):
  86. return self.linear(self.flatten(x))
  87. class MyModel(nn.Module):
  88. def __init__(self, res, n_classes) -> None:
  89. super().__init__()
  90. self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
  91. self.my_block = MyBlock(4 * (res**2), n_classes)
  92. def forward(self, x):
  93. y = self.conv(x)
  94. return self.my_block(y)
  95. res = 32
  96. n_clss = 10
  97. module = MyModel(res, n_clss)
  98. # TEST
  99. q_util = SelectiveQuantizer()
  100. q_util.quantize_module(module)
  101. x = torch.rand(1, 3, res, res)
  102. # ASSERT
  103. with torch.no_grad():
  104. y = module(x)
  105. torch.testing.assert_close(y.size(), (1, n_clss))
  106. self.assertTrue(isinstance(module.conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  107. self.assertTrue(isinstance(module.my_block.linear, SelectiveQuantizer.mapping_instructions[nn.Linear].quantized_target_class))
  108. def test_static_selective_skip_quantization(self):
  109. # ARRANGE
  110. class MyModel(nn.Module):
  111. def __init__(self) -> None:
  112. super().__init__()
  113. self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
  114. self.conv2 = SkipQuantization(nn.Conv2d(8, 16, kernel_size=3, padding=1))
  115. def forward(self, x):
  116. return self.conv2(self.conv1(x))
  117. module = MyModel()
  118. # TEST
  119. q_util = SelectiveQuantizer()
  120. q_util.quantize_module(module)
  121. x = torch.rand(1, 3, 32, 32)
  122. # ASSERT
  123. with torch.no_grad():
  124. y = module(x)
  125. torch.testing.assert_close(y.size(), (1, 16, 32, 32))
  126. self.assertTrue(isinstance(module.conv1, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  127. self.assertTrue(isinstance(module.conv2, nn.Conv2d))
  128. def test_dynamic_skip_quantization(self):
  129. # ARRANGE
  130. class MyModel(nn.Module):
  131. def __init__(self) -> None:
  132. super().__init__()
  133. self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
  134. self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
  135. def forward(self, x):
  136. return self.conv2(self.conv1(x))
  137. module = MyModel()
  138. # TEST
  139. q_util = SelectiveQuantizer()
  140. q_util.register_skip_quantization(layer_names={"conv2"})
  141. q_util.quantize_module(module)
  142. x = torch.rand(1, 3, 32, 32)
  143. # ASSERT
  144. with torch.no_grad():
  145. y = module(x)
  146. torch.testing.assert_close(y.size(), (1, 16, 32, 32))
  147. self.assertTrue(isinstance(module.conv1, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  148. self.assertTrue(isinstance(module.conv2, nn.Conv2d))
  149. def test_custom_quantized_mapping_wrapper_explicit_from_float(self):
  150. # ARRANGE
  151. class MyBlock(nn.Module):
  152. def __init__(self, in_feats, out_feats) -> None:
  153. super().__init__()
  154. self.flatten = nn.Flatten()
  155. self.linear = nn.Linear(in_feats, out_feats)
  156. def forward(self, x):
  157. return self.linear(self.flatten(x))
  158. class MyQuantizedBlock(SGQuantMixin):
  159. # NOTE: **kwargs are necessary because quant descriptors are passed there!
  160. @classmethod
  161. def from_float(cls, float_instance: MyBlock, **kwargs):
  162. return cls(in_feats=float_instance.linear.in_features, out_feats=float_instance.linear.out_features)
  163. def __init__(self, in_feats, out_feats) -> None:
  164. super().__init__()
  165. self.flatten = nn.Flatten()
  166. self.linear = quant_nn.QuantLinear(in_feats, out_feats)
  167. def forward(self, x):
  168. return self.linear(self.flatten(x))
  169. class MyModel(nn.Module):
  170. def __init__(self, res, n_classes) -> None:
  171. super().__init__()
  172. self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
  173. self.my_block = QuantizedMapping(float_module=MyBlock(4 * (res**2), n_classes), quantized_target_class=MyQuantizedBlock)
  174. def forward(self, x):
  175. y = self.conv(x)
  176. return self.my_block(y)
  177. res = 32
  178. n_clss = 10
  179. module = MyModel(res, n_clss)
  180. # TEST
  181. q_util = SelectiveQuantizer()
  182. q_util.quantize_module(module)
  183. x = torch.rand(1, 3, res, res)
  184. # ASSERT
  185. with torch.no_grad():
  186. y = module(x)
  187. torch.testing.assert_close(y.size(), (1, n_clss))
  188. self.assertTrue(isinstance(module.conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  189. self.assertTrue(isinstance(module.my_block, MyQuantizedBlock))
  190. def test_custom_quantized_mapping_wrapper_implicit_from_float(self):
  191. # ARRANGE
  192. class MyBlock(nn.Module):
  193. def __init__(self, in_feats, out_feats) -> None:
  194. super().__init__()
  195. self.in_feats = in_feats
  196. self.out_feats = out_feats
  197. self.flatten = nn.Flatten()
  198. self.linear = nn.Linear(in_feats, out_feats)
  199. def forward(self, x):
  200. return self.linear(self.flatten(x))
  201. class MyQuantizedBlock(SGQuantMixin):
  202. def __init__(self, in_feats, out_feats) -> None:
  203. super().__init__()
  204. self.flatten = nn.Flatten()
  205. self.linear = quant_nn.QuantLinear(in_feats, out_feats)
  206. def forward(self, x):
  207. return self.linear(self.flatten(x))
  208. class MyModel(nn.Module):
  209. def __init__(self, res, n_classes) -> None:
  210. super().__init__()
  211. self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
  212. self.my_block = QuantizedMapping(float_module=MyBlock(4 * (res**2), n_classes), quantized_target_class=MyQuantizedBlock)
  213. def forward(self, x):
  214. y = self.conv(x)
  215. return self.my_block(y)
  216. res = 32
  217. n_clss = 10
  218. module = MyModel(res, n_clss)
  219. # TEST
  220. q_util = SelectiveQuantizer()
  221. q_util.quantize_module(module)
  222. x = torch.rand(1, 3, res, res)
  223. # ASSERT
  224. with torch.no_grad():
  225. y = module(x)
  226. torch.testing.assert_close(y.size(), (1, n_clss))
  227. self.assertTrue(isinstance(module.conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  228. self.assertTrue(isinstance(module.my_block, MyQuantizedBlock))
  229. def test_custom_quantized_mapping_register_with_decorator(self):
  230. # ARRANGE
  231. class MyBlock(nn.Module):
  232. def __init__(self, in_feats, out_feats) -> None:
  233. super().__init__()
  234. self.in_feats = in_feats
  235. self.out_feats = out_feats
  236. self.flatten = nn.Flatten()
  237. self.linear = nn.Linear(in_feats, out_feats)
  238. def forward(self, x):
  239. return self.linear(self.flatten(x))
  240. @register_quantized_module(float_source=MyBlock)
  241. class MyQuantizedBlock(SGQuantMixin):
  242. def __init__(self, in_feats, out_feats) -> None:
  243. super().__init__()
  244. self.flatten = nn.Flatten()
  245. self.linear = quant_nn.QuantLinear(in_feats, out_feats)
  246. def forward(self, x):
  247. return self.linear(self.flatten(x))
  248. class MyModel(nn.Module):
  249. def __init__(self, res, n_classes) -> None:
  250. super().__init__()
  251. self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
  252. self.my_block = MyBlock(4 * (res**2), n_classes)
  253. def forward(self, x):
  254. y = self.conv(x)
  255. return self.my_block(y)
  256. res = 32
  257. n_clss = 10
  258. module = MyModel(res, n_clss)
  259. # TEST
  260. q_util = SelectiveQuantizer()
  261. q_util.quantize_module(module)
  262. x = torch.rand(1, 3, res, res)
  263. # ASSERT
  264. with torch.no_grad():
  265. y = module(x)
  266. torch.testing.assert_close(y.size(), (1, n_clss))
  267. self.assertTrue(MyQuantizedBlock is not None)
  268. self.assertTrue(isinstance(module.conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  269. self.assertTrue(isinstance(module.my_block, MyQuantizedBlock))
  270. def test_dynamic_quantized_mapping(self):
  271. # ARRANGE
  272. class MyBlock(nn.Module):
  273. def __init__(self, in_feats, out_feats) -> None:
  274. super().__init__()
  275. self.in_feats = in_feats
  276. self.out_feats = out_feats
  277. self.flatten = nn.Flatten()
  278. self.linear = nn.Linear(in_feats, out_feats)
  279. def forward(self, x):
  280. return self.linear(self.flatten(x))
  281. class MyQuantizedBlock(SGQuantMixin):
  282. def __init__(self, in_feats, out_feats) -> None:
  283. super().__init__()
  284. self.flatten = nn.Flatten()
  285. self.linear = quant_nn.QuantLinear(in_feats, out_feats)
  286. def forward(self, x):
  287. return self.linear(self.flatten(x))
  288. class MyModel(nn.Module):
  289. def __init__(self, res, n_classes) -> None:
  290. super().__init__()
  291. self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
  292. self.my_block = MyBlock(4 * (res**2), n_classes)
  293. def forward(self, x):
  294. y = self.conv(x)
  295. return self.my_block(y)
  296. res = 32
  297. n_clss = 10
  298. module = MyModel(res, n_clss)
  299. # TEST
  300. q_util = SelectiveQuantizer()
  301. q_util.register_quantization_mapping(layer_names={"my_block"}, quantized_target_class=MyQuantizedBlock)
  302. q_util.quantize_module(module)
  303. x = torch.rand(1, 3, res, res)
  304. # ASSERT
  305. with torch.no_grad():
  306. y = module(x)
  307. torch.testing.assert_close(y.size(), (1, n_clss))
  308. self.assertTrue(isinstance(module.conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  309. self.assertTrue(isinstance(module.my_block, MyQuantizedBlock))
  310. def test_non_default_quant_descriptors_are_piped(self):
  311. # ARRANGE
  312. class MyModel(nn.Module):
  313. def __init__(self) -> None:
  314. super().__init__()
  315. self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
  316. def forward(self, x):
  317. return self.conv1(x)
  318. module = MyModel()
  319. # TEST
  320. q_util = SelectiveQuantizer(default_quant_modules_calib_method="max")
  321. q_util.quantize_module(module)
  322. x = torch.rand(1, 3, 32, 32)
  323. # ASSERT
  324. with torch.no_grad():
  325. y = module(x)
  326. torch.testing.assert_close(y.size(), (1, 8, 32, 32))
  327. self.assertTrue(isinstance(module.conv1, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  328. self.assertTrue(type(module.conv1._input_quantizer._calibrator) == MaxCalibrator)
  329. self.assertTrue(type(module.conv1._weight_quantizer._calibrator) == MaxCalibrator)
  330. def test_different_quant_descriptors_are_piped(self):
  331. # ARRANGE
  332. class MyModel(nn.Module):
  333. def __init__(self) -> None:
  334. super().__init__()
  335. self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
  336. self.conv2 = nn.Conv2d(8, 8, kernel_size=3, padding=1)
  337. def forward(self, x):
  338. return self.conv2(self.conv1(x))
  339. module = MyModel()
  340. # TEST
  341. q_util = SelectiveQuantizer()
  342. q_util.register_quantization_mapping(
  343. layer_names={"conv1"},
  344. quantized_target_class=QuantConv2d,
  345. input_quant_descriptor=QuantDescriptor(calib_method="max"),
  346. weights_quant_descriptor=QuantDescriptor(calib_method="histogram"),
  347. )
  348. q_util.register_quantization_mapping(
  349. layer_names={"conv2"},
  350. quantized_target_class=QuantConv2d,
  351. input_quant_descriptor=QuantDescriptor(calib_method="histogram"),
  352. weights_quant_descriptor=QuantDescriptor(calib_method="max"),
  353. )
  354. q_util.quantize_module(module)
  355. x = torch.rand(1, 3, 32, 32)
  356. # ASSERT
  357. with torch.no_grad():
  358. y = module(x)
  359. torch.testing.assert_close(y.size(), (1, 8, 32, 32))
  360. self.assertTrue(isinstance(module.conv1, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  361. self.assertTrue(type(module.conv1._input_quantizer._calibrator) == MaxCalibrator)
  362. self.assertTrue(type(module.conv1._weight_quantizer._calibrator) == HistogramCalibrator)
  363. self.assertTrue(type(module.conv2._input_quantizer._calibrator) == HistogramCalibrator)
  364. self.assertTrue(type(module.conv2._weight_quantizer._calibrator) == MaxCalibrator)
  365. def test_quant_descriptors_are_piped_to_custom_quant_modules_if_has_kwargs(self):
  366. # ARRANGE
  367. class MyBlock(nn.Module):
  368. def __init__(self, in_feats, out_feats) -> None:
  369. super().__init__()
  370. self.in_feats = in_feats
  371. self.out_feats = out_feats
  372. self.flatten = nn.Flatten()
  373. self.linear = nn.Linear(in_feats, out_feats)
  374. def forward(self, x):
  375. return self.linear(self.flatten(x))
  376. class MyQuantizedBlock(SGQuantMixin):
  377. # NOTE: if **kwargs is present in the signature, the quant descriptors are passed through it!
  378. # NOTE: because we don't override `from_float`,
  379. # then the float instance should have `in_feats` and `out_feats` as state
  380. def __init__(self, in_feats, out_feats, **kwargs) -> None:
  381. super().__init__()
  382. self.flatten = nn.Flatten()
  383. self.linear = quant_nn.QuantLinear(
  384. in_feats,
  385. out_feats,
  386. quant_desc_input=kwargs["quant_desc_input"],
  387. quant_desc_weight=kwargs["quant_desc_weight"],
  388. )
  389. def forward(self, x):
  390. return self.linear(self.flatten(x))
  391. class MyModel(nn.Module):
  392. def __init__(self, res, n_classes) -> None:
  393. super().__init__()
  394. self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
  395. self.my_block = QuantizedMapping(
  396. float_module=MyBlock(4 * (res**2), n_classes),
  397. quantized_target_class=MyQuantizedBlock,
  398. input_quant_descriptor=QuantDescriptor(calib_method="max"),
  399. weights_quant_descriptor=QuantDescriptor(calib_method="histogram"),
  400. )
  401. def forward(self, x):
  402. y = self.conv(x)
  403. return self.my_block(y)
  404. res = 32
  405. n_clss = 10
  406. module = MyModel(res, n_clss)
  407. # TEST
  408. q_util = SelectiveQuantizer()
  409. q_util.quantize_module(module)
  410. x = torch.rand(1, 3, res, res)
  411. # ASSERT
  412. with torch.no_grad():
  413. y = module(x)
  414. torch.testing.assert_close(y.size(), (1, n_clss))
  415. self.assertTrue(isinstance(module.conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  416. self.assertTrue(isinstance(module.my_block, MyQuantizedBlock))
  417. self.assertTrue(type(module.my_block.linear._input_quantizer._calibrator) == MaxCalibrator)
  418. self.assertTrue(type(module.my_block.linear._weight_quantizer._calibrator) == HistogramCalibrator)
  419. def test_quant_descriptors_are_piped_to_custom_quant_modules_if_expects_in_init(self):
  420. # ARRANGE
  421. class MyBlock(nn.Module):
  422. def __init__(self, in_feats, out_feats) -> None:
  423. super().__init__()
  424. self.in_feats = in_feats
  425. self.out_feats = out_feats
  426. self.flatten = nn.Flatten()
  427. self.linear = nn.Linear(in_feats, out_feats)
  428. def forward(self, x):
  429. return self.linear(self.flatten(x))
  430. class MyQuantizedBlock(SGQuantMixin):
  431. # NOTE: since `quant_desc_input` and `quant_desc_weight` appear in the signature, the quant descriptors are passed there!
  432. # NOTE: because we don't override `from_float`,
  433. # then the float instance should have `in_feats` and `out_feats` as state
  434. def __init__(self, in_feats, out_feats, quant_desc_input, quant_desc_weight) -> None:
  435. super().__init__()
  436. self.flatten = nn.Flatten()
  437. self.linear = quant_nn.QuantLinear(
  438. in_feats,
  439. out_feats,
  440. quant_desc_input=quant_desc_input,
  441. quant_desc_weight=quant_desc_weight,
  442. )
  443. def forward(self, x):
  444. return self.linear(self.flatten(x))
  445. class MyModel(nn.Module):
  446. def __init__(self, res, n_classes) -> None:
  447. super().__init__()
  448. self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
  449. self.my_block = QuantizedMapping(
  450. float_module=MyBlock(4 * (res**2), n_classes),
  451. quantized_target_class=MyQuantizedBlock,
  452. input_quant_descriptor=QuantDescriptor(calib_method="max"),
  453. weights_quant_descriptor=QuantDescriptor(calib_method="histogram"),
  454. )
  455. def forward(self, x):
  456. y = self.conv(x)
  457. return self.my_block(y)
  458. res = 32
  459. n_clss = 10
  460. module = MyModel(res, n_clss)
  461. # TEST
  462. q_util = SelectiveQuantizer()
  463. q_util.quantize_module(module)
  464. x = torch.rand(1, 3, res, res)
  465. # ASSERT
  466. with torch.no_grad():
  467. y = module(x)
  468. torch.testing.assert_close(y.size(), (1, n_clss))
  469. self.assertTrue(isinstance(module.conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  470. self.assertTrue(isinstance(module.my_block, MyQuantizedBlock))
  471. self.assertTrue(type(module.my_block.linear._input_quantizer._calibrator) == MaxCalibrator)
  472. self.assertTrue(type(module.my_block.linear._weight_quantizer._calibrator) == HistogramCalibrator)
  473. def test_quant_descriptors_are_not_piped_if_custom_quant_module_does_not_expect_them(self):
  474. # ARRANGE
  475. class MyBlock(nn.Module):
  476. def __init__(self, in_feats, out_feats) -> None:
  477. super().__init__()
  478. self.in_feats = in_feats
  479. self.out_feats = out_feats
  480. self.flatten = nn.Flatten()
  481. self.linear = nn.Linear(in_feats, out_feats)
  482. def forward(self, x):
  483. return self.linear(self.flatten(x))
  484. class MyQuantizedBlock(SGQuantMixin):
  485. # NOTE: because we don't override `from_float`,
  486. # then the float instance should have `in_feats` and `out_feats` as state
  487. def __init__(self, in_feats, out_feats) -> None:
  488. super().__init__()
  489. self.flatten = nn.Flatten()
  490. self.linear = quant_nn.QuantLinear(in_feats, out_feats)
  491. def forward(self, x):
  492. return self.linear(self.flatten(x))
  493. class MyModel(nn.Module):
  494. def __init__(self, res, n_classes) -> None:
  495. super().__init__()
  496. self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
  497. self.my_block = QuantizedMapping(float_module=MyBlock(4 * (res**2), n_classes), quantized_target_class=MyQuantizedBlock)
  498. def forward(self, x):
  499. y = self.conv(x)
  500. return self.my_block(y)
  501. res = 32
  502. n_clss = 10
  503. module = MyModel(res, n_clss)
  504. # TEST
  505. q_util = SelectiveQuantizer()
  506. q_util.quantize_module(module)
  507. x = torch.rand(1, 3, res, res)
  508. # ASSERT
  509. with torch.no_grad():
  510. y = module(x)
  511. torch.testing.assert_close(y.size(), (1, n_clss))
  512. self.assertTrue(isinstance(module.conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  513. self.assertTrue(isinstance(module.my_block, MyQuantizedBlock))
  514. def test_custom_quantized_mappings_are_recursively_quantized_if_required(self):
  515. # ARRANGE
  516. class MyBlock(nn.Module):
  517. def __init__(self, in_feats, out_feats) -> None:
  518. super().__init__()
  519. self.in_feats = in_feats
  520. self.out_feats = out_feats
  521. self.flatten = nn.Flatten()
  522. self.linear = nn.Linear(in_feats, out_feats)
  523. def forward(self, x):
  524. return self.linear(self.flatten(x))
  525. class MyQuantizedBlock(SGQuantMixin):
  526. def __init__(self, in_feats, out_feats) -> None:
  527. super().__init__()
  528. self.flatten = nn.Flatten()
  529. self.linear = nn.Linear(in_feats, out_feats)
  530. def forward(self, x):
  531. return self.linear(self.flatten(x))
  532. class MyModel(nn.Module):
  533. def __init__(self, res, n_classes) -> None:
  534. super().__init__()
  535. self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
  536. self.my_block = QuantizedMapping(
  537. float_module=MyBlock(4 * (res**2), n_classes),
  538. quantized_target_class=MyQuantizedBlock,
  539. action=QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE,
  540. )
  541. def forward(self, x):
  542. y = self.conv(x)
  543. return self.my_block(y)
  544. res = 32
  545. n_clss = 10
  546. module = MyModel(res, n_clss)
  547. # TEST
  548. q_util = SelectiveQuantizer()
  549. q_util.quantize_module(module)
  550. x = torch.rand(1, 3, res, res)
  551. # ASSERT
  552. with torch.no_grad():
  553. y = module(x)
  554. torch.testing.assert_close(y.size(), (1, n_clss))
  555. self.assertTrue(isinstance(module.conv, SelectiveQuantizer.mapping_instructions[nn.Conv2d].quantized_target_class))
  556. self.assertTrue(isinstance(module.my_block, MyQuantizedBlock))
  557. self.assertTrue(isinstance(module.my_block.linear, SelectiveQuantizer.mapping_instructions[nn.Linear].quantized_target_class))
  558. def test_torchvision_resnet_sg_vanilla_quantization_matches_pytorch_quantization(self):
  559. resnet_sg = torchvision.models.resnet50(pretrained=True)
  560. # SG SELECTIVE QUANTIZATION
  561. sq = SelectiveQuantizer(
  562. custom_mappings={
  563. torch.nn.Conv2d: QuantizedMetadata(
  564. torch.nn.Conv2d,
  565. quant_nn.QuantConv2d,
  566. action=QuantizedMetadata.ReplacementAction.REPLACE,
  567. input_quant_descriptor=QuantDescriptor(calib_method="histogram"),
  568. weights_quant_descriptor=QuantDescriptor(calib_method="max", axis=0),
  569. ),
  570. torch.nn.Linear: QuantizedMetadata(
  571. torch.nn.Linear,
  572. quant_nn.QuantLinear,
  573. action=QuantizedMetadata.ReplacementAction.REPLACE,
  574. input_quant_descriptor=QuantDescriptor(calib_method="histogram"),
  575. weights_quant_descriptor=QuantDescriptor(calib_method="max", axis=0),
  576. ),
  577. torch.nn.AdaptiveAvgPool2d: QuantizedMetadata(
  578. torch.nn.AdaptiveAvgPool2d,
  579. quant_nn.QuantAdaptiveAvgPool2d,
  580. action=QuantizedMetadata.ReplacementAction.REPLACE,
  581. input_quant_descriptor=QuantDescriptor(calib_method="max"),
  582. ),
  583. },
  584. default_per_channel_quant_modules=True,
  585. )
  586. sq.quantize_module(resnet_sg, preserve_state_dict=True)
  587. # PYTORCH-QUANTIZATION
  588. quant_desc_input = QuantDescriptor(calib_method="histogram")
  589. quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
  590. quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
  591. quant_modules.initialize()
  592. resnet_pyquant = torchvision.models.resnet50(pretrained=True)
  593. quant_modules.deactivate()
  594. for (n1, p1), (n2, p2) in zip(resnet_sg.named_parameters(), resnet_pyquant.named_parameters()):
  595. torch.testing.assert_allclose(p1, p2)
  596. x = torch.rand(1, 3, 128, 128)
  597. with torch.no_grad():
  598. y_pyquant = resnet_pyquant(x)
  599. y_sg = resnet_sg(x)
  600. torch.testing.assert_close(y_sg, y_pyquant)
  601. def test_sg_resnet_sg_vanilla_quantization_matches_pytorch_quantization(self):
  602. # SG SELECTIVE QUANTIZATION
  603. sq = SelectiveQuantizer(
  604. custom_mappings={
  605. torch.nn.Conv2d: QuantizedMetadata(
  606. torch.nn.Conv2d,
  607. quant_nn.QuantConv2d,
  608. action=QuantizedMetadata.ReplacementAction.REPLACE,
  609. input_quant_descriptor=QuantDescriptor(calib_method="histogram"),
  610. weights_quant_descriptor=QuantDescriptor(calib_method="max", axis=0),
  611. ),
  612. torch.nn.Linear: QuantizedMetadata(
  613. torch.nn.Linear,
  614. quant_nn.QuantLinear,
  615. action=QuantizedMetadata.ReplacementAction.REPLACE,
  616. input_quant_descriptor=QuantDescriptor(calib_method="histogram"),
  617. weights_quant_descriptor=QuantDescriptor(calib_method="max", axis=0),
  618. ),
  619. torch.nn.AdaptiveAvgPool2d: QuantizedMetadata(
  620. torch.nn.AdaptiveAvgPool2d,
  621. quant_nn.QuantAdaptiveAvgPool2d,
  622. action=QuantizedMetadata.ReplacementAction.REPLACE,
  623. input_quant_descriptor=QuantDescriptor(calib_method="max"),
  624. ),
  625. },
  626. default_per_channel_quant_modules=True,
  627. )
  628. resnet_sg: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
  629. sq.quantize_module(resnet_sg, preserve_state_dict=True)
  630. # PYTORCH-QUANTIZATION
  631. quant_desc_input = QuantDescriptor(calib_method="histogram")
  632. quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
  633. quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
  634. quant_modules.initialize()
  635. resnet_pyquant: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
  636. quant_modules.deactivate()
  637. for (n1, p1), (n2, p2) in zip(resnet_sg.named_parameters(), resnet_pyquant.named_parameters()):
  638. torch.testing.assert_allclose(p1, p2)
  639. x = torch.rand(1, 3, 128, 128)
  640. with torch.no_grad():
  641. y_pyquant = resnet_pyquant(x)
  642. y_sg = resnet_sg(x)
  643. torch.testing.assert_close(y_sg, y_pyquant)
  644. if __name__ == "__main__":
  645. unittest.main()
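Building on the two ResNet parity tests above: the practical difference from pytorch-quantization's global monkey-patching is that SelectiveQuantizer can leave named layers in float. A short sketch under the same assumptions as the tests (torchvision weights available; "conv1" is the stem conv of torchvision's ResNet-50):

import torchvision
from torch import nn

from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

resnet = torchvision.models.resnet50(pretrained=True)

sq = SelectiveQuantizer()
sq.register_skip_quantization(layer_names={"conv1"})  # keep the stem conv in float
sq.quantize_module(resnet, preserve_state_dict=True)

assert isinstance(resnet.conv1, nn.Conv2d)  # the skipped stem stays a regular conv
quantized_linear_cls = SelectiveQuantizer.mapping_instructions[nn.Linear].quantized_target_class
assert isinstance(resnet.fc, quantized_linear_cls)  # the rest of the network is quantized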