@@ -32,7 +32,6 @@ def register_quantized_module(
     """
 
     def decorator(quant_module: Type[SGQuantMixin]) -> Type[SGQuantMixin]:
-
         if float_source in SelectiveQuantizer.mapping_instructions:
             metadata = SelectiveQuantizer.mapping_instructions[float_source]
             raise ValueError(f"`{float_source}` is already registered with following metadata {metadata}")
@@ -54,11 +53,12 @@ def register_quantized_module(
 
 
 class SelectiveQuantizer:
-
     """
-    :param custom_mappings: custom mappings that extend the default mappings with extra behaviour
-    :param default_quant_modules_calib_method: default calibration method (default='percentile')
-    :param default_per_channel_quant_modules: whether quant modules should be per channel (default=False)
+    :param custom_mappings: custom mappings that extend the default mappings with extra behaviour
+    :param default_per_channel_quant_weights: whether quant module weights should be per channel (default=True)
+    :param default_quant_modules_calib_method_weights: default calibration method for weights (default='max')
+    :param default_quant_modules_calib_method_inputs: default calibration method for inputs (default='percentile')
+    :param default_learn_amax: EXPERIMENTAL! whether quant modules should have learnable amax (default=False)
     """
 
     if _imported_pytorch_quantization_failure is not None:
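For context, a minimal usage sketch of the new constructor arguments (the import path is assumed from the repo layout and is not shown in this diff):

    # Hedged sketch, not part of the diff: import path and values are illustrative.
    from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

    q_util = SelectiveQuantizer(
        default_quant_modules_calib_method_weights="max",        # weights: max calibration, per-channel by default
        default_quant_modules_calib_method_inputs="percentile",  # inputs: histogram-based percentile calibration
        default_per_channel_quant_weights=True,
        default_learn_amax=False,  # EXPERIMENTAL, off by default
    )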
@@ -93,24 +93,38 @@ class SelectiveQuantizer:
     }  # DEFAULT MAPPING INSTRUCTIONS
 
     def __init__(
-        self, *, custom_mappings: dict = None, default_quant_modules_calib_method: str = "max", default_per_channel_quant_modules: bool = True
+        self,
+        *,
+        custom_mappings: dict = None,
+        default_quant_modules_calib_method_weights: str = "max",
+        default_quant_modules_calib_method_inputs: str = "percentile",
+        default_per_channel_quant_weights: bool = True,
+        default_learn_amax: bool = False,
     ) -> None:
         super().__init__()
-        self.default_quant_modules_calib_method = default_quant_modules_calib_method
-        self.default_per_channel_quant_modules = default_per_channel_quant_modules
+        self.default_quant_modules_calib_method_weights = default_quant_modules_calib_method_weights
+        self.default_quant_modules_calib_method_inputs = default_quant_modules_calib_method_inputs
+        self.default_per_channel_quant_weights = default_per_channel_quant_weights
+        self.default_learn_amax = default_learn_amax
         self.mapping_instructions = self.mapping_instructions.copy()
         if custom_mappings is not None:
             self.mapping_instructions.update(custom_mappings)  # OVERRIDE DEFAULT WITH CUSTOM. CUSTOM IS PRIORITIZED
 
     def _get_default_quant_descriptor(self, for_weights=False):
-        if self.default_quant_modules_calib_method in ["percentile", "mse", "entropy"]:
-            calib_method_type = "histogram"
-        else:
-            calib_method_type = "max"
+        methods = {"percentile": "histogram", "mse": "histogram", "entropy": "histogram", "histogram": "histogram", "max": "max"}
+
+        if for_weights:
+            axis = 0 if self.default_per_channel_quant_weights else None
 
-        if self.default_per_channel_quant_modules and for_weights:
-            return QuantDescriptor(calib_method=calib_method_type, axis=0)
-        return QuantDescriptor(calib_method=calib_method_type)
+            learn_amax = self.default_learn_amax
+            if self.default_learn_amax and self.default_per_channel_quant_weights:
+                logger.error("Learnable amax is supported only for per-tensor quantization. Disabling it for weights quantization!")
+                learn_amax = False
+
+            return QuantDescriptor(calib_method=methods[self.default_quant_modules_calib_method_weights], axis=axis, learn_amax=learn_amax)
+        else:
+            # activations stay per-tensor by default
+            return QuantDescriptor(calib_method=methods[self.default_quant_modules_calib_method_inputs], learn_amax=self.default_learn_amax)
 
     def register_skip_quantization(self, *, layer_names: Set[str]):
         self.mapping_instructions.update(
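For reference, the two default descriptors produced above map onto pytorch_quantization's QuantDescriptor roughly as follows (a sketch with the default arguments, not part of the diff):

    # Sketch only: what _get_default_quant_descriptor returns with the defaults above.
    from pytorch_quantization.tensor_quant import QuantDescriptor

    # for_weights=True: per-channel (axis=0) max calibration; learn_amax is forced off for per-channel weights
    weights_desc = QuantDescriptor(calib_method="max", axis=0, learn_amax=False)

    # for_weights=False: per-tensor histogram calibration
    # ("percentile", "mse" and "entropy" all resolve to the histogram calibrator via the methods dict)
    inputs_desc = QuantDescriptor(calib_method="histogram", learn_amax=False)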
@@ -192,7 +206,18 @@ class SelectiveQuantizer:
 
         # COPY STATE DICT IF NEEDED
         if preserve_state_dict:
-            q_instance.load_state_dict(float_module.state_dict(), strict=True)
+            # quant state dict may have additional parameters for Clip and strict loading will fail
+            # if we find at least one Clip module in q_instance, disable strict loading and hope for the best
+            strict_load = True
+            for k in q_instance.state_dict().keys():
+                if "clip.clip_value_max" in k or "clip.clip_value_min" in k:
+                    strict_load = False
+                    logger.debug(
+                        "Instantiating quant module in non-strict mode leaving Clip parameters non-initialized. Use QuantizationCalibrator to initialize them."
+                    )
+                    break
+
+            q_instance.load_state_dict(float_module.state_dict(), strict=strict_load)
 
         return q_instance
 
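A plain-PyTorch sketch (not super_gradients code) of the failure mode handled above: the quantized module owns Clip parameters that the float module's state dict does not contain, so strict=True would raise on missing keys.

    # Illustrative stand-in modules; the parameter names mimic the Clip values a TensorQuantizer adds with learn_amax.
    import torch
    import torch.nn as nn

    class FloatBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

    class QuantBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            self.clip_value_min = nn.Parameter(torch.tensor(-1.0))
            self.clip_value_max = nn.Parameter(torch.tensor(1.0))

    float_module, q_instance = FloatBlock(), QuantBlock()
    # strict=True would raise "Missing key(s) ... clip_value_min, clip_value_max"; strict=False only reports them,
    # leaving the Clip parameters at their initial values until a calibrator fills them in.
    missing, unexpected = q_instance.load_state_dict(float_module.state_dict(), strict=False)
    print(missing)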
@@ -257,6 +282,7 @@ class SelectiveQuantizer:
 
             if metadata.action == QuantizedMetadata.ReplacementAction.REPLACE:
                 replace()
+
             elif metadata.action == QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE:
                 replace()
                 recurse_quantize()