#475 Feature/sg 000 clean start prints

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-000_clean_start_prints
@@ -18,7 +18,7 @@ try:
 
     _imported_clear_ml_failure = None
 except (ImportError, NameError, ModuleNotFoundError) as import_err:
-    logger.warn("Failed to import deci_lab_client")
+    logger.debug("Failed to import clearml")
     _imported_clear_ml_failure = import_err
 
 
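The hunk above demotes the failed-import message from `warn` to `debug` and corrects it to reference `clearml` rather than `deci_lab_client`, which is the point of this "clean start prints" change: a missing optional dependency should not print a warning on every import. A minimal sketch of the same guard pattern follows; the logger import path and the `_assert_clearml_available()` helper are illustrative assumptions, not code from this file.

from super_gradients.common.abstractions.abstract_logger import get_logger

logger = get_logger(__name__)

try:
    from clearml import Task  # optional integration

    _imported_clear_ml_failure = None
except (ImportError, NameError, ModuleNotFoundError) as import_err:
    # debug instead of warning: users without clearml installed get a clean start
    logger.debug("Failed to import clearml")
    _imported_clear_ml_failure = import_err


def _assert_clearml_available():
    """Re-raise the cached import error only when clearml is actually needed."""
    if _imported_clear_ml_failure is not None:
        raise _imported_clear_ml_failure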
@@ -65,7 +65,7 @@ def get_data_loader(config_name, dataset_cls, train, dataset_params=None, datalo
     GlobalHydra.instance().clear()
     sg_recipes_dir = pkg_resources.resource_filename("super_gradients.recipes", "")
     dataset_config = os.path.join("dataset_params", config_name)
-    with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir)):
+    with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir), version_base="1.2"):
         # config is relative to a module
         cfg = compose(config_name=normalize_path(dataset_config))
 
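This and the following Hydra hunks make the same change: `initialize_config_dir` is now called with `version_base="1.2"`. Hydra 1.2 prints a "version_base is not specified" warning when the argument is omitted, so pinning it keeps the start-up output clean. A standalone sketch of the compose-API call, assuming a local `conf/` directory containing `my_dataset_params.yaml` (both names are illustrative):

import os

from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()

# Pinning version_base suppresses the warning Hydra 1.2 emits on startup
# when no compatibility level is given.
with initialize_config_dir(config_dir=os.path.abspath("conf"), version_base="1.2"):
    cfg = compose(config_name="my_dataset_params")
    print(cfg)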
@@ -23,7 +23,7 @@ def get(config_name, overriding_params: Dict = None) -> Dict:
         overriding_params = dict()
     GlobalHydra.instance().clear()
     sg_recipes_dir = pkg_resources.resource_filename("super_gradients.recipes", "")
-    with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir)):
+    with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir), version_base="1.2"):
         cfg = compose(config_name=normalize_path(config_name))
         cfg = hydra.utils.instantiate(cfg)
         training_params = cfg.training_hyperparams
@@ -32,15 +32,13 @@ def load_experiment_cfg(experiment_name: str, ckpt_root_dir: str = None) -> Dict
 
     resume_dir = Path(checkpoints_dir_path) / ".hydra"
     if not resume_dir.exists():
-        raise FileNotFoundError(
-            f"The checkpoint directory {checkpoints_dir_path} does not include .hydra artifacts to resume the experiment."
-        )
+        raise FileNotFoundError(f"The checkpoint directory {checkpoints_dir_path} does not include .hydra artifacts to resume the experiment.")
 
     # Load overrides that were used in previous run
     overrides_cfg = list(OmegaConf.load(resume_dir / "overrides.yaml"))
 
     GlobalHydra.instance().clear()
-    with initialize_config_dir(config_dir=normalize_path(str(resume_dir))):
+    with initialize_config_dir(config_dir=normalize_path(str(resume_dir)), version_base="1.2"):
         cfg = compose(config_name="config.yaml", overrides=overrides_cfg)
     return cfg
 
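The same `version_base` pin is applied to the resume path, which re-composes the original `config.yaml` together with the overrides Hydra recorded in `.hydra/overrides.yaml`. A minimal sketch of that resume flow, assuming a checkpoint directory layout like the one checked above (the `checkpoints/my_experiment` path is illustrative and `normalize_path` is omitted):

from pathlib import Path

from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

resume_dir = Path("checkpoints/my_experiment/.hydra")

# Command-line overrides from the original run were saved by Hydra here.
overrides_cfg = list(OmegaConf.load(resume_dir / "overrides.yaml"))

GlobalHydra.instance().clear()
with initialize_config_dir(config_dir=str(resume_dir), version_base="1.2"):
    cfg = compose(config_name="config.yaml", overrides=overrides_cfg)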
@@ -16,8 +16,7 @@ import os
 from enum import Enum
 from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model
 from super_gradients.training.utils import get_param
-from super_gradients.training.utils.distributed_training_utils import get_local_rank, \
-    get_world_size
+from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
 from torch.distributed import all_gather
 
 logger = get_logger(__name__)
@@ -29,33 +28,32 @@ try:
 
     _imported_pytorch_quantization_failure = None
 except (ImportError, NameError, ModuleNotFoundError) as import_err:
-    logger.warning("Failed to import pytorch_quantization")
+    logger.debug("Failed to import pytorch_quantization")
     _imported_pytorch_quantization_failure = import_err
 
 
 class QuantizationLevel(str, Enum):
-    FP32 = 'FP32'
-    FP16 = 'FP16'
-    INT8 = 'INT8'
-    HYBRID = 'Hybrid'
+    FP32 = "FP32"
+    FP16 = "FP16"
+    INT8 = "INT8"
+    HYBRID = "Hybrid"
 
     @staticmethod
     def from_string(quantization_level: str) -> Enum:
         quantization_level = quantization_level.lower()
-        if quantization_level == 'fp32':
+        if quantization_level == "fp32":
             return QuantizationLevel.FP32
-        elif quantization_level == 'fp16':
+        elif quantization_level == "fp16":
             return QuantizationLevel.FP16
-        elif quantization_level == 'int8':
+        elif quantization_level == "int8":
             return QuantizationLevel.INT8
-        elif quantization_level == 'hybrid':
+        elif quantization_level == "hybrid":
             return QuantizationLevel.HYBRID
         else:
             raise NotImplementedError(f'Quantization Level: "{quantization_level}" is not supported')
 
 
-def export_qat_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tuple,
-                    per_channel_quantization: bool = False):
+def export_qat_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tuple, per_channel_quantization: bool = False):
     """
     Method for exporting onnx after QAT.
 
@@ -72,15 +70,14 @@ def export_qat_onnx(model: torch.nn.Module, onnx_filename: str, input_shape: tup
         quant_nn.TensorQuantizer.use_fb_fake_quant = True
         # Export ONNX for multiple batch sizes
         logger.info("Creating ONNX file: " + onnx_filename)
-        dummy_input = torch.randn(input_shape, device='cuda')
+        dummy_input = torch.randn(input_shape, device="cuda")
         opset_version = 13 if per_channel_quantization else 12
-        torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=opset_version,
-                          enable_onnx_checker=False,
-                          do_constant_folding=True)
+        torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=opset_version, enable_onnx_checker=False, do_constant_folding=True)
 
 
-def calibrate_model(model: torch.nn.Module, calib_data_loader: torch.utils.data.DataLoader, method: str = "percentile",
-                    num_calib_batches: int = 2, percentile: float = 99.99):
+def calibrate_model(
+    model: torch.nn.Module, calib_data_loader: torch.utils.data.DataLoader, method: str = "percentile", num_calib_batches: int = 2, percentile: float = 99.99
+):
     """
     Calibrates torch model with quantized modules.
 
@@ -106,9 +103,7 @@ def calibrate_model(model: torch.nn.Module, calib_data_loader: torch.utils.data.
             else:
                 _compute_amax(model, method=method)
     else:
-        raise ValueError(
-            "Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(
-                method) + ".")
+        raise ValueError("Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(method) + ".")
 
 
 def _collect_stats(model, data_loader, num_batches):
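A hedged usage sketch for `calibrate_model` and `export_qat_onnx` based on the signatures reformatted above. It assumes `pytorch_quantization` is installed and a CUDA device is available; the import path, toy model, and toy loader are illustrative, not taken from this diff.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_quantization import quant_modules

# Module path assumed to be the file under review.
from super_gradients.training.utils.quantization_utils import calibrate_model, export_qat_onnx

quant_modules.initialize()  # monkey-patch torch layers with quantized counterparts
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).cuda().eval()

calib_loader = DataLoader(TensorDataset(torch.randn(8, 3, 32, 32), torch.zeros(8)), batch_size=4)

# Collect activation statistics on a couple of calibration batches.
calibrate_model(model=model, calib_data_loader=calib_loader, method="percentile", num_calib_batches=2, percentile=99.99)

# Per-channel quantization requires ONNX opset 13 (handled inside export_qat_onnx).
export_qat_onnx(model, "qat_model.onnx", input_shape=(1, 3, 32, 32), per_channel_quantization=True)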
@@ -125,7 +120,7 @@ def _collect_stats(model, data_loader, num_batches):
         # Feed data to the network for collecting stats
         for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches, disable=local_rank > 0):
             if world_size > 1:
-                all_batches = [torch.zeros_like(image, device='cuda') for _ in range(world_size)]
+                all_batches = [torch.zeros_like(image, device="cuda") for _ in range(world_size)]
                 all_gather(all_batches, image.cuda())
             else:
                 all_batches = [image]
@@ -233,9 +228,17 @@ class QATCallback(PhaseCallback):
 
     """
 
-    def __init__(self, start_epoch: int, quant_modules_calib_method: str = "percentile",
-                 per_channel_quant_modules: bool = False, calibrate: bool = True, calibrated_model_path: str = None,
-                 calib_data_loader: DataLoader = None, num_calib_batches: int = 2, percentile: float = 99.99):
+    def __init__(
+        self,
+        start_epoch: int,
+        quant_modules_calib_method: str = "percentile",
+        per_channel_quant_modules: bool = False,
+        calibrate: bool = True,
+        calibrated_model_path: str = None,
+        calib_data_loader: DataLoader = None,
+        num_calib_batches: int = 2,
+        percentile: float = 99.99,
+    ):
         super(QATCallback, self).__init__(Phase.TRAIN_EPOCH_START)
         self._validate_args(start_epoch, quant_modules_calib_method, calibrate, calibrated_model_path)
         self.start_epoch = start_epoch
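A hedged sketch of instantiating the callback with the reformatted signature. The values mirror the defaults above; in practice they typically come from the recipe's QAT parameters, and attaching the callback through the `phase_callbacks` training hyperparameter is one way to wire it in (shown here as an assumption, not a prescription).

# QATCallback is the class defined in the hunk above.
qat_callback = QATCallback(
    start_epoch=10,                           # switch to QAT at this epoch
    quant_modules_calib_method="percentile",
    per_channel_quant_modules=True,
    calibrate=True,                           # run calibration before QAT starts
    calib_data_loader=None,                   # falls back to the train loader
    num_calib_batches=2,
    percentile=99.99,
)

training_params = {
    "phase_callbacks": [qat_callback],
    # ... remaining training hyperparameters
}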
@@ -254,11 +257,10 @@ class QATCallback(PhaseCallback):
             raise ValueError("start_epoch must be positive.")
         if quant_modules_calib_method not in ["percentile", "mse", "entropy", "max"]:
             raise ValueError(
-                "Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(
-                    self.quant_modules_calib_method) + ".")
+                "Unsupported quantization calibration method, expected one of: percentile, mse, entropy, max, got " + str(self.quant_modules_calib_method) + "."
+            )
         if not calibrate and calibrated_model_path is None:
-            logger.warning(
-                "calibrate=False and no calibrated_model_path is given. QAT will be on an uncalibrated model.")
+            logger.warning("calibrate=False and no calibrated_model_path is given. QAT will be on an uncalibrated model.")
 
     def __call__(self, context: PhaseContext):
         if context.epoch == self.start_epoch:
@@ -273,7 +275,7 @@ class QATCallback(PhaseCallback):
             if self.calibrated_model_path is not None:
                 checkpoint_path = self.calibrated_model_path
             elif self.start_epoch > 0:
-                checkpoint_path = os.path.join(context.ckpt_dir, 'ckpt_best.pth')
+                checkpoint_path = os.path.join(context.ckpt_dir, "ckpt_best.pth")
 
             qat_net = models.get(context.architecture, arch_params=context.arch_params.to_dict(), checkpoint_path=checkpoint_path)
 
@@ -289,7 +291,7 @@ class QATCallback(PhaseCallback):
             context.context_methods._reset_best_metric()
 
             # SET NEW FILENAME FOR THE BEST CHECKPOINT SO WE DON'T OVERRIDE THE PREVIOUS ONES
-            context.context_methods.set_ckpt_best_name('qat_ckpt_best.pth')
+            context.context_methods.set_ckpt_best_name("qat_ckpt_best.pth")
 
             # FINALLY, SET THE QAT NET TO CONTINUE TRAINING
             context.context_methods.set_net(qat_net)
@@ -301,13 +303,16 @@ class QATCallback(PhaseCallback):
         :param context: PhaseContext, current phase context.
         """
         self.calib_data_loader = self.calib_data_loader or context.train_loader
-        calibrate_model(model=context.net,
-                        calib_data_loader=self.calib_data_loader,
-                        method=self.quant_modules_calib_method,
-                        num_calib_batches=self.num_calib_batches,
-                        percentile=self.percentile)
-        method_desc = self.quant_modules_calib_method + '_' + str(
-            self.percentile) if self.quant_modules_calib_method == 'percentile' else self.quant_modules_calib_method
+        calibrate_model(
+            model=context.net,
+            calib_data_loader=self.calib_data_loader,
+            method=self.quant_modules_calib_method,
+            num_calib_batches=self.num_calib_batches,
+            percentile=self.percentile,
+        )
+        method_desc = (
+            self.quant_modules_calib_method + "_" + str(self.percentile) if self.quant_modules_calib_method == "percentile" else self.quant_modules_calib_method
+        )
 
         if not context.ddp_silent_mode:
             logger.info("Performing additional validation on calibrated model...")
@@ -317,11 +322,8 @@ class QATCallback(PhaseCallback):
 
         if not context.ddp_silent_mode:
             logger.info("Calibrate model " + context.metric_to_watch + ": " + str(calibrated_acc))
-            context.sg_logger.add_checkpoint(tag='ckpt_calibrated_' + method_desc + '.pth',
-                                             state_dict={"net": context.net.state_dict(), "acc": calibrated_acc})
-            context.sg_logger.add_scalar("Calibrated_Model_" + context.metric_to_watch,
-                                         calibrated_acc,
-                                         global_step=self.start_epoch)
+            context.sg_logger.add_checkpoint(tag="ckpt_calibrated_" + method_desc + ".pth", state_dict={"net": context.net.state_dict(), "acc": calibrated_acc})
+            context.sg_logger.add_scalar("Calibrated_Model_" + context.metric_to_watch, calibrated_acc, global_step=self.start_epoch)
 
     def _initialize_quant_modules(self):
         """
@@ -332,9 +334,9 @@ class QATCallback(PhaseCallback):
             raise _imported_pytorch_quantization_failure
         else:
             if self.quant_modules_calib_method in ["percentile", "mse", "entropy"]:
-                calib_method_type = 'histogram'
+                calib_method_type = "histogram"
             else:
-                calib_method_type = 'max'
+                calib_method_type = "max"
 
             if self.per_channel_quant_modules:
                 quant_desc_input = QuantDescriptor(calib_method=calib_method_type)
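For context on what this hunk selects between: `"histogram"` backs the percentile/mse/entropy calibrators, while `"max"` uses simple max calibration. Below is a hedged sketch of setting such a descriptor as the default input quantization before quant modules are initialized, using the public pytorch_quantization API; it illustrates the library pattern, not the exact body of `_initialize_quant_modules`.

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor

# Histogram-based calibration for activations; weight descriptors keep their defaults.
quant_desc_input = QuantDescriptor(calib_method="histogram")
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

quant_modules.initialize()  # replace torch layers with their quantized counterparts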
@@ -372,13 +374,14 @@ class PostQATConversionCallback(PhaseCallback):
             best_ckpt_path = os.path.join(context.ckpt_dir, "qat_ckpt_best.pth")
             onnx_path = os.path.join(context.ckpt_dir, "qat_ckpt_best.onnx")
 
-            load_checkpoint_to_model(ckpt_local_path=best_ckpt_path,
-                                     net=context.net,
-                                     load_weights_only=True,
-                                     load_ema_as_net=context.training_params.ema,
-                                     strict=True,
-                                     load_backbone=False
-                                     )
+            load_checkpoint_to_model(
+                ckpt_local_path=best_ckpt_path,
+                net=context.net,
+                load_weights_only=True,
+                load_ema_as_net=context.training_params.ema,
+                strict=True,
+                load_backbone=False,
+            )
             per_channel_quant_modules = get_param(context.training_params.qat_params, "per_channel_quant_modules")
             export_qat_onnx(context.net.module, onnx_path, self.dummy_input_size, per_channel_quant_modules)
 
@@ -15,7 +15,7 @@ class PPYoloETests(unittest.TestCase):
     def get_model_arch_params(self, config_name):
         GlobalHydra.instance().clear()
         sg_recipes_dir = pkg_resources.resource_filename("super_gradients.recipes", "")
-        with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir)):
+        with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir), version_base="1.2"):
             cfg = compose(config_name=normalize_path(config_name))
             cfg = hydra.utils.instantiate(cfg)
             arch_params = cfg.arch_params