#537 Quantization infra mods for different calibrators and learnable amax

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/AL-706-selective-qat
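
For context: this PR splits SelectiveQuantizer's single calibration setting into separate methods for weights and inputs, makes per-channel quantization a weights-only option, and adds an experimental learnable-amax flag. A minimal usage sketch of the updated constructor, based on the hunks below (`model` is a placeholder nn.Module):

    from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

    q_util = SelectiveQuantizer(
        default_quant_modules_calib_method_weights="max",        # "max" -> max calibrator
        default_quant_modules_calib_method_inputs="percentile",  # "percentile"/"mse"/"entropy"/"histogram" -> histogram calibrator
        default_per_channel_quant_weights=True,                  # per-channel now applies to weights only
        default_learn_amax=False,                                # EXPERIMENTAL; per-tensor only
    )
    q_util.quantize_module(model)  # replaces supported layers with their quantized counterparts in place
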
@@ -16,7 +16,12 @@ def non_default_calibrators_example():
     module = MyModel()
 
     # Initialize the quantization utility, with different calibrators, and quantize the module
-    q_util = SelectiveQuantizer(default_quant_modules_calib_method="percentile", default_per_channel_quant_modules=False)
+    q_util = SelectiveQuantizer(
+        default_quant_modules_calib_method_weights="percentile",
+        default_quant_modules_calib_method_inputs="entropy",
+        default_per_channel_quant_weights=False,
+        default_learn_amax=False,
+    )
     q_util.quantize_module(module)
 
     print(module)  # You should expect to see QuantConv2d, with Histogram calibrators
@@ -50,7 +50,7 @@ def e2e_example():
     # CALIBRATE (PTQ)
     train_loader = cifar10_train()
     calib = QuantizationCalibrator()
-    calib.calibrate_model(module, method=q_util.default_quant_modules_calib_method, calib_data_loader=train_loader)
+    calib.calibrate_model(module, method=q_util.default_quant_modules_calib_method_inputs, calib_data_loader=train_loader)
 
     module.cuda()
     # SANITY
@@ -4,21 +4,30 @@ from torch import nn
 
 import super_gradients
 from super_gradients import Trainer
+from super_gradients.modules.quantization.resnet_bottleneck import QuantBottleneck as sg_QuantizedBottleneck
 from super_gradients.training import MultiGPUMode
 from super_gradients.training import models as sg_models
 from super_gradients.training.dataloaders import imagenet_train, imagenet_val
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics.metric_utils import get_metrics_dict
+from super_gradients.training.models.classification_models.resnet import Bottleneck
 from super_gradients.training.models.classification_models.resnet import Bottleneck as sg_Bottleneck
 from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
 from super_gradients.training.utils.quantization.core import QuantizedMetadata
 from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
-from super_gradients.modules.quantization.resnet_bottleneck import QuantBottleneck as sg_QuantizedBottleneck
 from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer
 
 
 def naive_quantize(model: nn.Module):
-    q_util = SelectiveQuantizer(default_quant_modules_calib_method="max", default_per_channel_quant_modules=True)
+    q_util = SelectiveQuantizer(
+        default_quant_modules_calib_method_weights="max",
+        default_quant_modules_calib_method_inputs="percentile",
+        default_per_channel_quant_weights=True,
+        default_learn_amax=False,
+    )
+    # SG already registers a non-naive QuantBottleneck (as in selective_quantize() below); pop it for the sake of this example
+    if Bottleneck in q_util.mapping_instructions:
+        q_util.mapping_instructions.pop(Bottleneck)
     q_util.quantize_module(model)
 
     return model
@@ -33,7 +42,13 @@ def selective_quantize(model: nn.Module):
         ),
     }
 
-    sq_util = SelectiveQuantizer(custom_mappings=mappings, default_quant_modules_calib_method="max", default_per_channel_quant_modules=True)
+    sq_util = SelectiveQuantizer(
+        custom_mappings=mappings,
+        default_quant_modules_calib_method_weights="max",
+        default_quant_modules_calib_method_inputs="percentile",
+        default_per_channel_quant_weights=True,
+        default_learn_amax=False,
+    )
     sq_util.quantize_module(model)
 
     return model
@@ -91,8 +106,8 @@ if __name__ == "__main__":
     model = models[args.model_name]().cuda()
 
     if args.calibrate:
-        calibrator = QuantizationCalibrator(verbose=False)
-        calibrator.calibrate_model(model, method="max", calib_data_loader=train_dataloader, num_calib_batches=1024 // args.batch or 1)
+        calibrator = QuantizationCalibrator(verbose=True)
+        calibrator.calibrate_model(model, method="percentile", calib_data_loader=train_dataloader, num_calib_batches=1024 // args.batch or 1)
 
     trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=val_dataloader)
 
@@ -64,7 +64,9 @@ class XBlock(nn.Module):  # From figure 4
                 nn.ReLU(),
                 nn.Conv2d(se_channels, inter_channels, kernel_size=1, bias=True),
                 nn.Sigmoid(),
+                Residual(),
             )
+            self.se_residual = Residual()
         else:
             self.se = None
 
@@ -82,7 +84,7 @@ class XBlock(nn.Module):  # From figure 4
         x1 = self.conv_block_1(x)
         x1 = self.conv_block_2(x1)
         if self.se is not None:
-            x1 = x1 * self.se(x1)
+            x1 = self.se_residual(x1) * self.se(x1)
 
         x1 = self.conv_block_3(x1)
         x2 = self.shortcut(x)
@@ -27,11 +27,12 @@ except (ImportError, NameError, ModuleNotFoundError) as import_err:
 
 
 class QuantizationCalibrator:
-    def __init__(self, verbose: bool = True) -> None:
+    def __init__(self, torch_hist: bool = True, verbose: bool = True) -> None:
         if _imported_pytorch_quantization_failure is not None:
             raise _imported_pytorch_quantization_failure
         super().__init__()
         self.verbose = verbose
+        self.torch_hist = torch_hist
 
     def calibrate_model(
         self,
@@ -98,8 +99,8 @@ class QuantizationCalibrator:
         for name, module in model.named_modules():
             if isinstance(module, quant_nn.TensorQuantizer):
                 if module._calibrator is not None:
-                    module.enable_quant()
                     module.disable_calib()
+                    module.enable_quant()
                 else:
                     module.enable()
 
@@ -107,6 +108,8 @@ class QuantizationCalibrator:
         for name, module in model.named_modules():
             if isinstance(module, quant_nn.TensorQuantizer):
                 if module._calibrator is not None:
+                    if isinstance(module._calibrator, calib.HistogramCalibrator):
+                        module._calibrator._torch_hist = self.torch_hist  # TensorQuantizer does not expose this as an API
                     module.disable_quant()
                     module.enable_calib()
                 else:
@@ -120,5 +123,9 @@ class QuantizationCalibrator:
                         module.load_calib_amax()
                     else:
                         module.load_calib_amax(**kwargs)
+
+                if hasattr(module, "clip"):
+                    module.init_learn_amax()
+
                 if self.verbose:
                     print(f"{name:40}: {module}")
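A short sketch of how the calibrator is meant to be driven after this change, reusing `q_util`, `module`, and `train_loader` from the example hunks above; `torch_hist` is the new flag forwarded to pytorch-quantization's HistogramCalibrator:

    calib = QuantizationCalibrator(torch_hist=True, verbose=True)
    calib.calibrate_model(
        module,
        method=q_util.default_quant_modules_calib_method_inputs,  # e.g. "percentile"
        calib_data_loader=train_loader,
    )
    # after load_calib_amax(), quantizers that expose `clip` (i.e. learnable amax was requested) also get init_learn_amax()
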
@@ -32,7 +32,6 @@ def register_quantized_module(
     """
 
     def decorator(quant_module: Type[SGQuantMixin]) -> Type[SGQuantMixin]:
-
         if float_source in SelectiveQuantizer.mapping_instructions:
             metadata = SelectiveQuantizer.mapping_instructions[float_source]
             raise ValueError(f"`{float_source}` is already registered with following metadata {metadata}")
@@ -54,11 +53,12 @@ def register_quantized_module(
 
 
 class SelectiveQuantizer:
-
     """
-    :param custom_mappings:                     custom mappings that extend the default mappings with extra behaviour
-    :param default_quant_modules_calib_method:  default calibration method (default='percentile')
-    :param default_per_channel_quant_modules:   whether quant modules should be per channel (default=False)
+    :param custom_mappings:                             custom mappings that extend the default mappings with extra behaviour
+    :param default_per_channel_quant_weights:           whether quant module weights should be per channel (default=True)
+    :param default_quant_modules_calib_method_weights:  default calibration method for weights (default='max')
+    :param default_quant_modules_calib_method_inputs:   default calibration method for inputs (default='percentile')
+    :param default_learn_amax:                          EXPERIMENTAL! whether quant modules should have learnable amax (default=False)
     """
 
     if _imported_pytorch_quantization_failure is not None:
@@ -93,24 +93,38 @@ class SelectiveQuantizer:
     }  # DEFAULT MAPPING INSTRUCTIONS
 
     def __init__(
-        self, *, custom_mappings: dict = None, default_quant_modules_calib_method: str = "max", default_per_channel_quant_modules: bool = True
+        self,
+        *,
+        custom_mappings: dict = None,
+        default_quant_modules_calib_method_weights: str = "max",
+        default_quant_modules_calib_method_inputs: str = "percentile",
+        default_per_channel_quant_weights: bool = True,
+        default_learn_amax: bool = False,
     ) -> None:
         super().__init__()
-        self.default_quant_modules_calib_method = default_quant_modules_calib_method
-        self.default_per_channel_quant_modules = default_per_channel_quant_modules
+        self.default_quant_modules_calib_method_weights = default_quant_modules_calib_method_weights
+        self.default_quant_modules_calib_method_inputs = default_quant_modules_calib_method_inputs
+        self.default_per_channel_quant_weights = default_per_channel_quant_weights
+        self.default_learn_amax = default_learn_amax
         self.mapping_instructions = self.mapping_instructions.copy()
         if custom_mappings is not None:
             self.mapping_instructions.update(custom_mappings)  # OVERRIDE DEFAULT WITH CUSTOM. CUSTOM IS PRIORITIZED
 
     def _get_default_quant_descriptor(self, for_weights=False):
-        if self.default_quant_modules_calib_method in ["percentile", "mse", "entropy"]:
-            calib_method_type = "histogram"
-        else:
-            calib_method_type = "max"
+        methods = {"percentile": "histogram", "mse": "histogram", "entropy": "histogram", "histogram": "histogram", "max": "max"}
+
+        if for_weights:
+            axis = 0 if self.default_per_channel_quant_weights else None
 
-        if self.default_per_channel_quant_modules and for_weights:
-            return QuantDescriptor(calib_method=calib_method_type, axis=0)
-        return QuantDescriptor(calib_method=calib_method_type)
+            learn_amax = self.default_learn_amax
+            if self.default_learn_amax and self.default_per_channel_quant_weights:
+                logger.error("Learnable amax is supported only for per-tensor quantization. Disabling it for weight quantization!")
+                learn_amax = False
+
+            return QuantDescriptor(calib_method=methods[self.default_quant_modules_calib_method_weights], axis=axis, learn_amax=learn_amax)
+        else:
+            # activations stay per-tensor by default
+            return QuantDescriptor(calib_method=methods[self.default_quant_modules_calib_method_inputs], learn_amax=self.default_learn_amax)
 
     def register_skip_quantization(self, *, layer_names: Set[str]):
         self.mapping_instructions.update(
@@ -192,7 +206,18 @@ class SelectiveQuantizer:
 
         # COPY STATE DICT IF NEEDED
         if preserve_state_dict:
-            q_instance.load_state_dict(float_module.state_dict(), strict=True)
+            # the quant state dict may contain additional Clip parameters, so strict loading would fail
+            # if we find at least one Clip module in q_instance, disable strict loading and hope for the best
+            strict_load = True
+            for k in q_instance.state_dict().keys():
+                if "clip.clip_value_max" in k or "clip.clip_value_min" in k:
+                    strict_load = False
+                    logger.debug(
+                        "Instantiating quant module in non-strict mode leaving Clip parameters non-initilaized. Use QuantizationCalibrator to initialize them."
+                    )
+                    break
+
+            q_instance.load_state_dict(float_module.state_dict(), strict=strict_load)
 
         return q_instance
 
@@ -257,6 +282,7 @@ class SelectiveQuantizer:
 
                 if metadata.action == QuantizedMetadata.ReplacementAction.REPLACE:
                     replace()
+
                 elif metadata.action == QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE:
                     replace()
                     recurse_quantize()
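
As a quick illustration of the descriptor logic in _get_default_quant_descriptor above (a sketch, not part of the diff; `q` is a placeholder): per-channel weights get axis=0, and learnable amax is honored only for per-tensor quantization, so it is dropped for per-channel weights but kept for inputs:

    q = SelectiveQuantizer(
        default_quant_modules_calib_method_weights="max",
        default_quant_modules_calib_method_inputs="percentile",
        default_per_channel_quant_weights=True,
        default_learn_amax=True,
    )
    # weights descriptor -> QuantDescriptor(calib_method="max", axis=0, learn_amax=False)  (error logged, amax learning disabled)
    # inputs descriptor  -> QuantDescriptor(calib_method="histogram", learn_amax=True)
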
@@ -412,7 +412,7 @@ class QuantizationUtilityTest(unittest.TestCase):
         module = MyModel()
 
         # TEST
-        q_util = SelectiveQuantizer(default_quant_modules_calib_method="max")
+        q_util = SelectiveQuantizer(default_quant_modules_calib_method_inputs="max", default_quant_modules_calib_method_weights="max")
         q_util.quantize_module(module)
 
         x = torch.rand(1, 3, 32, 32)
@@ -730,7 +730,7 @@ class QuantizationUtilityTest(unittest.TestCase):
                     input_quant_descriptor=QuantDescriptor(calib_method="max"),
                 ),
             },
-            default_per_channel_quant_modules=True,
+            default_per_channel_quant_weights=True,
         )
 
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
@@ -754,8 +754,9 @@ class QuantizationUtilityTest(unittest.TestCase):
             torch.testing.assert_close(y_sg, y_pyquant)
 
     def test_sg_resnet_sg_vanilla_quantization_matches_pytorch_quantization(self):
-
         # SG SELECTIVE QUANTIZATION
+        from super_gradients.training.models.classification_models.resnet import Bottleneck
+
         sq = SelectiveQuantizer(
             custom_mappings={
                 torch.nn.Conv2d: QuantizedMetadata(
@@ -779,16 +780,27 @@ class QuantizationUtilityTest(unittest.TestCase):
                     input_quant_descriptor=QuantDescriptor(calib_method="max"),
                 ),
             },
-            default_per_channel_quant_modules=True,
+            default_per_channel_quant_weights=True,
         )
 
+        # SG registers a non-naive QuantBottleneck that would behave differently; pop it for testing purposes
+        if Bottleneck in sq.mapping_instructions:
+            sq.mapping_instructions.pop(Bottleneck)
+
         resnet_sg: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)
         sq.quantize_module(resnet_sg, preserve_state_dict=True)
 
         # PYTORCH-QUANTIZATION
         quant_desc_input = QuantDescriptor(calib_method="histogram")
+        quant_desc_weights = QuantDescriptor(calib_method="max", axis=0)
+
         quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
+        quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weights)
+
         quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
+        quant_nn.QuantLinear.set_default_quant_desc_weight(quant_desc_weights)
+
+        quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
 
         quant_modules.initialize()
         resnet_pyquant: nn.Module = super_gradients.training.models.get("resnet50", pretrained_weights="imagenet", num_classes=1000)