@@ -16,8 +16,7 @@ import os
 from enum import Enum
 from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model
 from super_gradients.training.utils import get_param
-from super_gradients.training.utils.distributed_training_utils import get_local_rank, \
-    get_world_size
+from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
 from torch.distributed import all_gather
 
 logger = get_logger(__name__)
@@ -29,33 +28,32 @@ try:
     _imported_pytorch_quantization_failure = None
 except (ImportError, NameError, ModuleNotFoundError) as import_err:
-    logger.warning("Failed to import pytorch_quantization")
+    logger.debug("Failed to import pytorch_quantization")
     _imported_pytorch_quantization_failure = import_err
 
 
 class QuantizationLevel(str, Enum):
-    FP32 = 'FP32'
-    FP16 = 'FP16'
-    INT8 = 'INT8'
-    HYBRID = 'Hybrid'
+    FP32 = "FP32"
+    FP16 = "FP16"
+    INT8 = "INT8"
+    HYBRID = "Hybrid"
 
     @staticmethod
     def from_string(quantization_level: str) -> Enum:
         quantization_level = quantization_level.lower()
-        if quantization_level == 'fp32':
+        if quantization_level == "fp32":
             return QuantizationLevel.FP32
-        elif quantization_level == 'fp16':
+        elif quantization_level == "fp16":
             return QuantizationLevel.FP16
-        elif quantization_level == 'int8':
+        elif quantization_level == "int8":
             return QuantizationLevel.INT8
-        elif quantization_level == 'hybrid':
+        elif quantization_level == "hybrid":
             return QuantizationLevel.HYBRID
         else:
             raise NotImplementedError(f'Quantization Level: "{quantization_level}" is not supported')
 
 
-def export_qat_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tuple,
-                    per_channel_quantization: bool = False):
+def export_qat_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tuple, per_channel_quantization: bool = False):
     """
     Method for exporting onnx after QAT.
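
As an aside on the `QuantizationLevel` hunk above, a minimal sketch of how the case-insensitive lookup behaves; the import path is an assumption based on where this module appears to live:

    # Assumed import path for the module under diff.
    from super_gradients.training.utils.quantization_utils import QuantizationLevel

    level = QuantizationLevel.from_string("Int8")   # input is lowercased first
    assert level is QuantizationLevel.INT8
    assert level == "INT8"                          # str-Enum members compare equal to their values

    QuantizationLevel.from_string("int4")           # raises NotImplementedError
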
@@ -72,15 +70,14 @@ def export_qat_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tup
     quant_nn.TensorQuantizer.use_fb_fake_quant = True
     # Export ONNX for multiple batch sizes
     logger.info("Creating ONNX file: " + onnx_filename)
-    dummy_input = torch.randn(input_shape, device='cuda')
+    dummy_input = torch.randn(input_shape, device="cuda")
     opset_version = 13 if per_channel_quantization else 12
-    torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=opset_version,
-                      enable_onnx_checker=False,
-                      do_constant_folding=True)
+    torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=opset_version, enable_onnx_checker=False, do_constant_folding=True)
 
 
-def calibrate_model(model: torch.nn.Module, calib_data_loader: torch.utils.data.DataLoader, method: str = "percentile",
-                    num_calib_batches: int = 2, percentile: float = 99.99):
+def calibrate_model(
+    model: torch.nn.Module, calib_data_loader: torch.utils.data.DataLoader, method: str = "percentile", num_calib_batches: int = 2, percentile: float = 99.99
+):
     """
     Calibrates torch model with quantized modules.
@@ -106,9 +103,7 @@ def calibrate_model(model: torch.nn.Module, calib_data_loader: torch.utils.data.
         else:
             _compute_amax(model, method=method)
     else:
-        raise ValueError(
-            "Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(
-                method) + ".")
+        raise ValueError("Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(method) + ".")
 
 
 def _collect_stats(model, data_loader, num_batches):
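
A hedged usage sketch for `calibrate_model` as reformatted above: it assumes the network was built while `pytorch_quantization`'s `quant_modules` patching was active (so its layers carry TensorQuantizers) and that a CUDA device is available, since `_collect_stats` moves batches to `cuda`. `build_model` and `train_loader` are illustrative names:

    from pytorch_quantization import quant_modules

    quant_modules.initialize()        # monkey-patch nn layers with quantized variants
    model = build_model().cuda()      # illustrative factory returning a torch.nn.Module
    quant_modules.deactivate()        # stop patching once the model is built

    # Collect activation statistics on a few batches, then compute amax ranges.
    calibrate_model(
        model=model,
        calib_data_loader=train_loader,  # illustrative DataLoader
        method="percentile",
        num_calib_batches=2,
        percentile=99.99,
    )
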
@@ -125,7 +120,7 @@ def _collect_stats(model, data_loader, num_batches):
     # Feed data to the network for collecting stats
     for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches, disable=local_rank > 0):
         if world_size > 1:
-            all_batches = [torch.zeros_like(image, device='cuda') for _ in range(world_size)]
+            all_batches = [torch.zeros_like(image, device="cuda") for _ in range(world_size)]
             all_gather(all_batches, image.cuda())
         else:
             all_batches = [image]
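
The `all_gather` in this hunk makes every rank calibrate on the union of all ranks' batches. A minimal, self-contained sketch of those semantics, assuming the script is launched with `torchrun` so the process-group env vars are set:

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()

    local = torch.full((2, 2), float(rank), device="cuda")
    # Pre-allocate one receive buffer per rank; all_gather fills them in rank order.
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    # Every rank now holds identical copies of all ranks' tensors.
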
@@ -233,9 +228,17 @@ class QATCallback(PhaseCallback):
     """
 
-    def __init__(self, start_epoch: int, quant_modules_calib_method: str = "percentile",
-                 per_channel_quant_modules: bool = False, calibrate: bool = True, calibrated_model_path: str = None,
-                 calib_data_loader: DataLoader = None, num_calib_batches: int = 2, percentile: float = 99.99):
+    def __init__(
+        self,
+        start_epoch: int,
+        quant_modules_calib_method: str = "percentile",
+        per_channel_quant_modules: bool = False,
+        calibrate: bool = True,
+        calibrated_model_path: str = None,
+        calib_data_loader: DataLoader = None,
+        num_calib_batches: int = 2,
+        percentile: float = 99.99,
+    ):
         super(QATCallback, self).__init__(Phase.TRAIN_EPOCH_START)
         self._validate_args(start_epoch, quant_modules_calib_method, calibrate, calibrated_model_path)
         self.start_epoch = start_epoch
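
For orientation, a sketch of constructing the reformatted `QATCallback` and handing it to a training run via `phase_callbacks`; this wiring is an assumption about the surrounding trainer API, and `calib_loader` / `training_params` are illustrative:

    qat_callback = QATCallback(
        start_epoch=10,                           # switch the net to QAT at this epoch
        quant_modules_calib_method="percentile",
        per_channel_quant_modules=True,
        calibrate=True,
        calib_data_loader=calib_loader,           # illustrative DataLoader
        num_calib_batches=2,
        percentile=99.99,
    )
    training_params["phase_callbacks"] = [qat_callback]
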
@@ -254,11 +257,10 @@ class QATCallback(PhaseCallback):
             raise ValueError("start_epoch must be positive.")
         if quant_modules_calib_method not in ["percentile", "mse", "entropy", "max"]:
             raise ValueError(
-                "Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(
-                    self.quant_modules_calib_method) + ".")
+                "Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(self.quant_modules_calib_method) + "."
+            )
         if not calibrate and calibrated_model_path is None:
-            logger.warning(
-                "calibrate=False and no calibrated_model_path is given. QAT will be on an uncalibrated model.")
+            logger.warning("calibrate=False and no calibrated_model_path is given. QAT will be on an uncalibrated model.")
 
     def __call__(self, context: PhaseContext):
         if context.epoch == self.start_epoch:
@@ -273,7 +275,7 @@ class QATCallback(PhaseCallback):
         if self.calibrated_model_path is not None:
             checkpoint_path = self.calibrated_model_path
         elif self.start_epoch > 0:
-            checkpoint_path = os.path.join(context.ckpt_dir, 'ckpt_best.pth')
+            checkpoint_path = os.path.join(context.ckpt_dir, "ckpt_best.pth")
 
         qat_net = models.get(context.architecture, arch_params=context.arch_params.to_dict(), checkpoint_path=checkpoint_path)
 
@@ -289,7 +291,7 @@ class QATCallback(PhaseCallback):
         context.context_methods._reset_best_metric()
 
         # SET NEW FILENAME FOR THE BEST CHECKPOINT SO WE DON'T OVERRIDE THE PREVIOUS ONES
-        context.context_methods.set_ckpt_best_name('qat_ckpt_best.pth')
+        context.context_methods.set_ckpt_best_name("qat_ckpt_best.pth")
 
         # FINALLY, SET THE QAT NET TO CONTINUE TRAINING
         context.context_methods.set_net(qat_net)
@@ -301,13 +303,16 @@ class QATCallback(PhaseCallback):
         :param context: PhaseContext, current phase context.
         """
         self.calib_data_loader = self.calib_data_loader or context.train_loader
-        calibrate_model(model=context.net,
-                        calib_data_loader=self.calib_data_loader,
-                        method=self.quant_modules_calib_method,
-                        num_calib_batches=self.num_calib_batches,
-                        percentile=self.percentile)
-        method_desc = self.quant_modules_calib_method + '_' + str(
-            self.percentile) if self.quant_modules_calib_method == 'percentile' else self.quant_modules_calib_method
+        calibrate_model(
+            model=context.net,
+            calib_data_loader=self.calib_data_loader,
+            method=self.quant_modules_calib_method,
+            num_calib_batches=self.num_calib_batches,
+            percentile=self.percentile,
+        )
+        method_desc = (
+            self.quant_modules_calib_method + "_" + str(self.percentile) if self.quant_modules_calib_method == "percentile" else self.quant_modules_calib_method
+        )
 
         if not context.ddp_silent_mode:
             logger.info("Performing additional validation on calibrated model...")
@@ -317,11 +322,8 @@ class QATCallback(PhaseCallback):
 
         if not context.ddp_silent_mode:
             logger.info("Calibrate model " + context.metric_to_watch + ": " + str(calibrated_acc))
-            context.sg_logger.add_checkpoint(tag='ckpt_calibrated_' + method_desc + '.pth',
-                                             state_dict={"net": context.net.state_dict(), "acc": calibrated_acc})
-            context.sg_logger.add_scalar("Calibrated_Model_" + context.metric_to_watch,
-                                         calibrated_acc,
-                                         global_step=self.start_epoch)
+            context.sg_logger.add_checkpoint(tag="ckpt_calibrated_" + method_desc + ".pth", state_dict={"net": context.net.state_dict(), "acc": calibrated_acc})
+            context.sg_logger.add_scalar("Calibrated_Model_" + context.metric_to_watch, calibrated_acc, global_step=self.start_epoch)
 
     def _initialize_quant_modules(self):
         """
@@ -332,9 +334,9 @@ class QATCallback(PhaseCallback):
             raise _imported_pytorch_quantization_failure
         else:
             if self.quant_modules_calib_method in ["percentile", "mse", "entropy"]:
-                calib_method_type = 'histogram'
+                calib_method_type = "histogram"
             else:
-                calib_method_type = 'max'
+                calib_method_type = "max"
 
             if self.per_channel_quant_modules:
                 quant_desc_input = QuantDescriptor(calib_method=calib_method_type)
@@ -372,13 +374,14 @@ class PostQATConversionCallback(PhaseCallback):
         best_ckpt_path = os.path.join(context.ckpt_dir, "qat_ckpt_best.pth")
         onnx_path = os.path.join(context.ckpt_dir, "qat_ckpt_best.onnx")
 
-        load_checkpoint_to_model(ckpt_local_path=best_ckpt_path,
-                                 net=context.net,
-                                 load_weights_only=True,
-                                 load_ema_as_net=context.training_params.ema,
-                                 strict=True,
-                                 load_backbone=False
-                                 )
+        load_checkpoint_to_model(
+            ckpt_local_path=best_ckpt_path,
+            net=context.net,
+            load_weights_only=True,
+            load_ema_as_net=context.training_params.ema,
+            strict=True,
+            load_backbone=False,
+        )
         per_channel_quant_modules = get_param(context.training_params.qat_params, "per_channel_quant_modules")
         export_qat_onnx(context.net.module, onnx_path, self.dummy_input_size, per_channel_quant_modules)
 
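Finally, a sketch of the standalone export that `PostQATConversionCallback` drives above; the model handle, filename, and shape are illustrative, and a CUDA device is required because the dummy input is allocated on `cuda`:

    # Per-channel quantization requires ONNX opset 13; per-tensor falls back to 12.
    export_qat_onnx(
        model=trained_qat_model,            # illustrative: a trained, calibrated QAT model
        onnx_filename="qat_ckpt_best.onnx",
        input_shape=(1, 3, 224, 224),
        per_channel_quantization=True,
    )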