#289 Fix flake8 errors in super_gradients/training

Merged
Ofri Masad merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000-fix_flake8_errors_training
25 changed files with 98 additions and 89 deletions
 1. src/super_gradients/training/__init__.py (+6, -5)
 2. src/super_gradients/training/datasets/data_augmentation.py (+2, -1)
 3. src/super_gradients/training/datasets/dataset_interfaces/__init__.py (+5, -7)
 4. src/super_gradients/training/datasets/datasets_utils.py (+5, -3)
 5. src/super_gradients/training/datasets/detection_datasets/__init__.py (+2, -1)
 6. src/super_gradients/training/datasets/segmentation_datasets/__init__.py (+6, -1)
 7. src/super_gradients/training/exceptions/kd_model_exceptions.py (+5, -8)
 8. src/super_gradients/training/kd_model/kd_model.py (+0, -2)
 9. src/super_gradients/training/losses/ssd_loss.py (+0, -1)
10. src/super_gradients/training/losses/stdc_loss.py (+4, -1)
11. src/super_gradients/training/losses/yolox_loss.py (+9, -5)
12. src/super_gradients/training/metrics/__init__.py (+6, -3)
13. src/super_gradients/training/models/classification_models/mobilenetv3.py (+2, -1)
14. src/super_gradients/training/models/detection_models/yolo_base.py (+1, -1)
15. src/super_gradients/training/models/segmentation_models/laddernet.py (+2, -1)
16. src/super_gradients/training/sg_model/sg_model.py (+4, -3)
17. src/super_gradients/training/trainer.py (+0, -3)
18. src/super_gradients/training/transforms/transforms.py (+1, -2)
19. src/super_gradients/training/utils/callbacks.py (+11, -23)
20. src/super_gradients/training/utils/checkpoint_utils.py (+13, -5)
21. src/super_gradients/training/utils/detection_utils.py (+3, -1)
22. src/super_gradients/training/utils/export_utils.py (+0, -1)
23. src/super_gradients/training/utils/get_model_stats.py (+4, -3)
24. src/super_gradients/training/utils/quantization_utils.py (+7, -6)
25. src/super_gradients/training/utils/ssd_utils.py (+0, -1)
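Most of the hunks below are mechanical fixes for flake8's E501 (line too long), plus a handful of F401 (unused import), E722 (bare except) and blank-line (E302/E303) fixes. Two wrapping styles recur: backslash continuation and implicit continuation inside parentheses. A minimal, self-contained sketch of both (the imported names are arbitrary stdlib stand-ins, not from this PR):

```python
# Style 1: explicit backslash continuation, as in the __init__.py import fixes below
from os.path import join, split, splitext, basename, dirname, \
    expanduser

# Style 2: implicit continuation inside parentheses, as in the yolox_loss.py fix below;
# PEP 8 prefers this form since there is no fragile trailing backslash to maintain
from os.path import (join, split, splitext,
                     basename, dirname, expanduser)
```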
src/super_gradients/training/__init__.py
@@ -1,10 +1,11 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 import super_gradients.training.utils.distributed_training_utils as distributed_training_utils
-from super_gradients.training.datasets import datasets_utils, DataAugmentation, DetectionDataSet, TestDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface
+from super_gradients.training.datasets import datasets_utils, DataAugmentation, DetectionDataSet, TestDatasetInterface, SegmentationTestDatasetInterface,\
+    DetectionTestDatasetInterface, ClassificationTestDatasetInterface
 from super_gradients.training.models import ARCHITECTURES
-from super_gradients.training.sg_model import SgModel, \
-    MultiGPUMode, StrictLoad
+from super_gradients.training.sg_model import SgModel, MultiGPUMode, StrictLoad
 from super_gradients.training.kd_model import KDModel

-__all__ = ['distributed_training_utils', 'datasets_utils', 'DataAugmentation', 'DetectionDataSet', 'TestDatasetInterface',
-           'ARCHITECTURES', 'SgModel', 'KDModel', 'MultiGPUMode', 'TestDatasetInterface', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface', 'ClassificationTestDatasetInterface', 'StrictLoad']
+__all__ = ['distributed_training_utils', 'datasets_utils', 'DataAugmentation', 'DetectionDataSet', 'TestDatasetInterface', 'ARCHITECTURES', 'SgModel',
+           'KDModel', 'MultiGPUMode', 'TestDatasetInterface', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface',
+           'ClassificationTestDatasetInterface', 'StrictLoad']
src/super_gradients/training/datasets/data_augmentation.py
@@ -74,7 +74,8 @@ IMAGENET_PCA = {
 class Lighting(object):
     """
     Lighting noise(AlexNet - style PCA - based noise)
-    Taken from fastai Imagenet training - https://github.com/fastai/imagenet-fast/blob/faa0f9dfc9e8e058ffd07a248724bf384f526fae/imagenet_nv/fastai_imagenet.py#L103
+    Taken from fastai Imagenet training -
+    https://github.com/fastai/imagenet-fast/blob/faa0f9dfc9e8e058ffd07a248724bf384f526fae/imagenet_nv/fastai_imagenet.py#L103
     To use:
         - training_params = {"imagenet_pca_aug": 0.1}
         - Default training_params arg is 0.0 ("don't use")
src/super_gradients/training/datasets/dataset_interfaces/__init__.py
@@ -6,10 +6,8 @@ from super_gradients.training.datasets.dataset_interfaces.dataset_interface impo
     PascalAUG2012SegmentationDataSetInterface, PascalVOC2012SegmentationDataSetInterface, \
     TestYoloDetectionDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface

-__all__ = ['DatasetInterface', 'TestDatasetInterface', 'LibraryDatasetInterface', 'ClassificationDatasetInterface',
-           'Cifar10DatasetInterface',
-           'Cifar100DatasetInterface', 'ImageNetDatasetInterface', 'TinyImageNetDatasetInterface',
-           'CoCoDetectionDatasetInterface',
-           'CoCo2014DetectionDatasetInterface', 'CoCoSegmentationDatasetInterface',
-           'PascalAUG2012SegmentationDataSetInterface',
-           'PascalVOC2012SegmentationDataSetInterface', 'TestYoloDetectionDatasetInterface', 'SegmentationTestDatasetInterface', 'DetectionTestDatasetInterface', 'ClassificationTestDatasetInterface']
+__all__ = ['DatasetInterface', 'TestDatasetInterface', 'LibraryDatasetInterface', 'ClassificationDatasetInterface', 'Cifar10DatasetInterface',
+           'Cifar100DatasetInterface', 'ImageNetDatasetInterface', 'TinyImageNetDatasetInterface', 'CoCoDetectionDatasetInterface',
+           'CoCo2014DetectionDatasetInterface', 'CoCoSegmentationDatasetInterface', 'PascalAUG2012SegmentationDataSetInterface',
+           'PascalVOC2012SegmentationDataSetInterface', 'TestYoloDetectionDatasetInterface', 'SegmentationTestDatasetInterface',
+           'DetectionTestDatasetInterface', 'ClassificationTestDatasetInterface']
src/super_gradients/training/datasets/datasets_utils.py
@@ -17,7 +17,6 @@ from super_gradients.training.datasets.detection_datasets.detection_dataset impo
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from deprecated import deprecated
 from matplotlib.patches import Rectangle
-from torch.utils.tensorboard import SummaryWriter
 from torchvision.datasets import ImageFolder
 from super_gradients.training.datasets.auto_augment import rand_augment_transform
 from torchvision.transforms import transforms, InterpolationMode, RandomResizedCrop
@@ -658,8 +657,11 @@ class DatasetStatisticsTensorboardLogger:

 def get_color_augmentation(rand_augment_config_string: str, color_jitter: tuple, crop_size=224, img_mean=[0.485, 0.456, 0.406]):
     """
-    Returns color augmentation class. As these augmentation cannot work on top one another, only one is returned according to rand_augment_config_string
-    :param rand_augment_config_string: string which defines the auto augment configurations. If none, color jitter will be returned. For possibile values see auto_augment.py
+    Returns color augmentation class. As these augmentations cannot work on top of one another, only one is returned
+    according to rand_augment_config_string
+
+    :param rand_augment_config_string: string which defines the auto augment configurations.
+                                       If none, color jitter will be returned. For possible values see auto_augment.py
     :param color_jitter: tuple for color jitter value.
     :param crop_size: relevant only for auto augment
     :param img_mean: relevant only for auto augment
src/super_gradients/training/datasets/detection_datasets/__init__.py
@@ -1,5 +1,6 @@
 from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataSet
 from super_gradients.training.datasets.detection_datasets.coco_detection import COCODetectionDataSet
+from super_gradients.training.datasets.detection_datasets.coco_detection_yolox import COCODetectionDatasetV2
 from super_gradients.training.datasets.detection_datasets.pascal_voc_detection import PascalVOCDetectionDataSet

-__all__ = ['DetectionDataSet', 'COCODetectionDataSet', 'PascalVOCDetectionDataSet']
+__all__ = ['DetectionDataSet', 'COCODetectionDataSet', 'PascalVOCDetectionDataSet', 'COCODetectionDatasetV2']
src/super_gradients/training/datasets/segmentation_datasets/__init__.py
@@ -1,4 +1,9 @@
-from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
+from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset
 from super_gradients.training.datasets.segmentation_datasets.coco_segmentation import CoCoSegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation import PascalAUG2012SegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation import PascalVOC2012SegmentationDataSet
+from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
+from super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation import SuperviselyPersonsDataset
+
+__all__ = ['SegmentationDataSet', 'CoCoSegmentationDataSet', 'PascalAUG2012SegmentationDataSet',
+           'PascalVOC2012SegmentationDataSet', 'CityscapesDataset', 'SuperviselyPersonsDataset']
src/super_gradients/training/exceptions/kd_model_exceptions.py
@@ -43,8 +43,8 @@ class InconsistentParamsException(KDModelException):

     def __init__(self, inconsistent_key1: str, inconsistent_key1_container_name: str, inconsistent_key2: str,
                  inconsistent_key2_container_name: str, ):
-        super().__init__(
-            inconsistent_key1 + " in " + inconsistent_key1_container_name + " must be equal to " + inconsistent_key2 + " in " + inconsistent_key2_container_name)
+        super().__init__(f"{inconsistent_key1} in {inconsistent_key1_container_name} must be equal to "
+                         f"{inconsistent_key2} in {inconsistent_key2_container_name}")


 class UnsupportedKDModelArgException(KDModelException):
@@ -55,8 +55,7 @@ class UnsupportedKDModelArgException(KDModelException):
     """

     def __init__(self, param_name: str, dict_name: str):
-        super().__init__(
-            param_name + " in " + dict_name + " not supported for KD models.")
+        super().__init__(param_name + " in " + dict_name + " not supported for KD models.")


 class TeacherKnowledgeException(KDModelException):
@@ -67,8 +66,7 @@ class TeacherKnowledgeException(KDModelException):
     """

     def __init__(self):
-        super().__init__(
-            "Expected: at least one of: teacher_pretrained_weights, teacher_checkpoint_path or load_kd_model_checkpoint=True")
+        super().__init__("Expected: at least one of: teacher_pretrained_weights, teacher_checkpoint_path or load_kd_model_checkpoint=True")


 class UndefinedNumClassesException(KDModelException):
@@ -78,5 +76,4 @@ class UndefinedNumClassesException(KDModelException):
         message -- explanation of the error
     """
     def __init__(self):
-        super().__init__(
-            'Number of classes must be defined in students and teachers arch params or by connecting to a dataset interface')
+        super().__init__("Number of classes must be defined in students and teachers arch params or by connecting to a dataset interface")
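The exception messages above also move from `+` concatenation to f-strings split across adjacent literals. Python joins adjacent string literals at compile time, so the message stays intact while each source line fits the length limit. A small sketch with hypothetical values:

```python
key1, container1 = "num_classes", "student_arch_params"  # hypothetical example values
key2, container2 = "num_classes", "teacher_arch_params"

# Adjacent f-string literals inside parentheses are joined at compile time,
# producing one message from two short source lines.
message = (f"{key1} in {container1} must be equal to "
           f"{key2} in {container2}")
print(message)  # num_classes in student_arch_params must be equal to num_classes in teacher_arch_params
```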
src/super_gradients/training/kd_model/kd_model.py
@@ -273,5 +273,3 @@ class KDModel(SgModel):

         state["net"] = best_net.state_dict()
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
-
-
src/super_gradients/training/losses/ssd_loss.py
@@ -3,7 +3,6 @@ from typing import Tuple
 import torch
 from torch import nn
 from torch.nn.modules.loss import _Loss
-import torch.nn.functional as F

 from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
 from super_gradients.training.utils.ssd_utils import DefaultBoxes
src/super_gradients/training/losses/stdc_loss.py
@@ -157,7 +157,10 @@ class STDCLoss(_Loss):
             self.detail_loss = DetailLoss(weights=detail_weights)

         if ohem_criteria is None:
-            ohem_criteria = OhemCELoss(threshold=threshold, mining_percent=mining_percent, ignore_lb=ignore_index) if num_classes > 1 else OhemBCELoss(threshold=threshold, mining_percent=mining_percent)
+            if num_classes > 1:
+                ohem_criteria = OhemCELoss(threshold=threshold, mining_percent=mining_percent, ignore_lb=ignore_index)
+            else:
+                ohem_criteria = OhemBCELoss(threshold=threshold, mining_percent=mining_percent)

         self.ce_ohem = ohem_criteria
         self.num_classes = num_classes
src/super_gradients/training/losses/yolox_loss.py
@@ -473,10 +473,14 @@ class YoloXDetectionLoss(_Loss):
         # FIND CELL CENTERS THAT ARE WITHIN +- self.center_sampling_radius CELLS FROM GROUND TRUTH BOXES CENTERS

         # define fake boxes: instead of ground truth boxes step +- self.center_sampling_radius from their centers
-        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0)
+        gt_bboxes_per_image_l = ((gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) -
+                                 self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0))
+        gt_bboxes_per_image_r = ((gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) +
+                                 self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0))
+        gt_bboxes_per_image_t = ((gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) -
+                                 self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0))
+        gt_bboxes_per_image_b = ((gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) +
+                                 self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0))

         c_l = x_centers_per_image - gt_bboxes_per_image_l
         c_r = gt_bboxes_per_image_r - x_centers_per_image
@@ -517,7 +521,7 @@ class YoloXDetectionLoss(_Loss):
         for gt_idx in range(num_gt):
             try:
                 _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx], largest=False)
-            except:
+            except Exception:
                 logger.warning("cost[gt_idx]: " + str(cost[gt_idx]) + " dynamic_ks[gt_idx]L " + str(dynamic_ks[gt_idx]))
             matching_matrix[gt_idx][pos_idx] = 1
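The `except:` → `except Exception:` change fixes flake8's E722. A bare `except:` also swallows `KeyboardInterrupt` and `SystemExit`, which is rarely intended. A minimal sketch of the pattern (the function and values are illustrative, not from the repo):

```python
import logging

logger = logging.getLogger(__name__)


def top_k_or_empty(values, k):
    """Illustrates the E722-clean pattern used in yolox_loss.py."""
    try:
        return sorted(values)[:k]
    except Exception as e:  # unlike a bare except, KeyboardInterrupt/SystemExit still propagate
        logger.warning("top-k failed: %s", e)
        return []


print(top_k_or_empty([3, 1, 2], k=2))   # [1, 2]
print(top_k_or_empty([3, None, 2], 2))  # [] -- the TypeError from sorted() is caught and logged
```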
src/super_gradients/training/metrics/__init__.py
@@ -1,6 +1,9 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE

-from super_gradients.training.metrics.classification_metrics import *
-from super_gradients.training.metrics.segmentation_metrics import *
+from super_gradients.training.metrics.classification_metrics import accuracy, Accuracy, Top5, ToyTestClassificationMetric
+from super_gradients.training.metrics.detection_metrics import DetectionMetrics
+from super_gradients.training.metrics.segmentation_metrics import PreprocessSegmentationMetricsArgs, PixelAccuracy, IoU, Dice, BinaryIOU, BinaryDice

-from super_gradients.training.metrics.detection_metrics import *
+
+__all__ = ['accuracy', 'Accuracy', 'Top5', 'ToyTestClassificationMetric', 'DetectionMetrics', 'PreprocessSegmentationMetricsArgs', 'PixelAccuracy', 'IoU',
+           'Dice', 'BinaryIOU', 'BinaryDice']
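Replacing `from module import *` with explicit names fixes flake8's F403/F405: with star imports the checker cannot see which names a module defines, so every usage looks potentially undefined. Declaring `__all__` keeps `from super_gradients.training.metrics import *` behaving exactly as before. A self-contained sketch of the re-export pattern (stdlib functions stand in for the metric classes):

```python
# package/__init__.py pattern, as applied to metrics/__init__.py above:
# explicit imports let flake8 trace every symbol, and __all__ pins down
# exactly what a star-import of this package exposes.
from math import sqrt, floor  # stand-ins for Accuracy, IoU, etc.

__all__ = ['sqrt', 'floor']
```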
src/super_gradients/training/models/classification_models/mobilenetv3.py
@@ -1,6 +1,7 @@
 """
 Creates a MobileNetV3 Model as defined in:
-Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019).
+Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu,
+Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019).
 Searching for MobileNetV3
 arXiv preprint arXiv:1905.02244.
 """
src/super_gradients/training/models/detection_models/yolo_base.py
@@ -284,7 +284,7 @@ class YoLoHead(nn.Module):
                                                                                  arch_params.depth_mult_factor)

         backbone_connector = [width_mult(c) if arch_params.scaled_backbone_width else c
-                             for c in arch_params.backbone_connection_channels]
+                              for c in arch_params.backbone_connection_channels]

         DownConv = GroupedConvBlock if depthwise else Conv
src/super_gradients/training/models/segmentation_models/laddernet.py
@@ -81,7 +81,8 @@ class LadderResNet(nn.Module):

     Reference:

-        - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
+        - He, Kaiming, et al. "Deep residual learning for image recognition."
+            Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

         - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
     """
src/super_gradients/training/sg_model/sg_model.py
@@ -425,7 +425,7 @@ class SgModel:

             # TODO: ITERATE BY MAX ITERS
             # FOR INFINITE SAMPLERS WE MUST BREAK WHEN REACHING LEN ITERATIONS.
-            if self._infinite_train_loader and batch_idx == len(self.train_loader)-1:
+            if self._infinite_train_loader and batch_idx == len(self.train_loader) - 1:
                 break

         if not self.ddp_silent_mode:
@@ -1030,7 +1030,8 @@ class SgModel:

                 # IN DDP- SET_EPOCH WILL CAUSE EVERY PROCESS TO BE EXPOSED TO THE ENTIRE DATASET BY SHUFFLING WITH A
                 # DIFFERENT SEED EACH EPOCH START
-                if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and hasattr(self.train_loader, "sampler") and hasattr(self.train_loader.sampler, "set_epoch"):
+                if self.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and hasattr(self.train_loader, "sampler")\
+                        and hasattr(self.train_loader.sampler, "set_epoch"):
                     self.train_loader.sampler.set_epoch(epoch)

                 train_metrics_tuple = self._train_epoch(epoch=epoch, silent_mode=silent_mode)
@@ -1759,7 +1760,7 @@ class SgModel:
         self.valid_monitored_values = sg_model_utils.update_monitored_values_dict(
             monitored_values_dict=self.valid_monitored_values, new_values_dict=pbar_message_dict)

-        if not silent_mode and evaluation_type==EvaluationType.VALIDATION:
+        if not silent_mode and evaluation_type == EvaluationType.VALIDATION:
             progress_bar_data_loader.write("===========================================================")
             sg_model_utils.display_epoch_summary(epoch=context.epoch, n_digits=4,
                                                  train_monitored_values=self.train_monitored_values,
src/super_gradients/training/trainer.py
@@ -1,8 +1,5 @@
 from omegaconf import DictConfig
 import hydra
-from super_gradients.training.sg_model import MultiGPUMode
-from super_gradients.common.abstractions.abstract_logger import get_logger
-import torch


 class Trainer:
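This hunk, like the ssd_loss.py, export_utils.py and ssd_utils.py hunks, simply deletes imports flake8 flags as F401 (imported but unused). Deletion is the right fix here; for an import that must stay (side effects or a deliberate re-export), the conventional alternative is a per-line waiver. An illustrative line, not from this PR:

```python
import json  # noqa: F401  -- example waiver for an import kept on purpose
```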
src/super_gradients/training/transforms/transforms.py
@@ -410,7 +410,7 @@ class DetectionMosaic(DetectionTransform):
             all_samples = [sample] + sample["additional_samples"]

             for i_mosaic, mosaic_sample in enumerate(all_samples):
-                img, _labels,  = mosaic_sample["image"], mosaic_sample["target"]
+                img, _labels = mosaic_sample["image"], mosaic_sample["target"]
                 _labels_seg = mosaic_sample.get("target_seg")

                 h0, w0 = img.shape[:2]  # orig hw
@@ -577,7 +577,6 @@ class DetectionMixup(DetectionTransform):
             )
             cp_scale_ratio *= jit_factor

-
             origin_h, origin_w = cp_img.shape[:2]
             target_h, target_w = origin_img.shape[:2]
             padded_img = np.zeros(
src/super_gradients/training/utils/callbacks.py
@@ -237,7 +237,7 @@ class DeciLabUploadCallback(PhaseCallback):

     @staticmethod
     def log_optimization_failed():
-        logger.info(f"We couldn't finish your model optimization. Visit https://console.deci.ai for details")
+        logger.info("We couldn't finish your model optimization. Visit https://console.deci.ai for details")

     def upload_model(self, model):
         """
@@ -304,10 +304,10 @@ class DeciLabUploadCallback(PhaseCallback):
             logger.info(f"Successfully added {model_name} to the model repository")

             optimized_model_name = f"{model_name}_1_1"
-            logger.info(f"We'll wait for the scheduled optimization to finish. Please don't close this window")
+            logger.info("We'll wait for the scheduled optimization to finish. Please don't close this window")
             success = self.get_optimization_status(optimized_model_name=optimized_model_name)
             if success:
-                logger.info(f"Successfully finished your model optimization. Visit https://console.deci.ai for details")
+                logger.info("Successfully finished your model optimization. Visit https://console.deci.ai for details")
             else:
                 DeciLabUploadCallback.log_optimization_failed()
         except Exception as ex:
@@ -460,11 +460,8 @@ class ExponentialLRCallback(LRCallbackBase):
         self.update_lr(context.optimizer, context.epoch, context.batch_idx)

     def is_lr_scheduling_enabled(self, context):
-        return (
-            self.training_params.lr_warmup_epochs
-            <= context.epoch
-            < self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
-        )
+        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
+        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs


 class PolyLRCallback(LRCallbackBase):
@@ -489,11 +486,8 @@ class PolyLRCallback(LRCallbackBase):
         self.update_lr(context.optimizer, context.epoch, context.batch_idx)

     def is_lr_scheduling_enabled(self, context):
-        return (
-            self.training_params.lr_warmup_epochs
-            <= context.epoch
-            < self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
-        )
+        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
+        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs


 class CosineLRCallback(LRCallbackBase):
@@ -519,11 +513,8 @@ class CosineLRCallback(LRCallbackBase):
         self.update_lr(context.optimizer, context.epoch, context.batch_idx)

     def is_lr_scheduling_enabled(self, context):
-        return (
-            self.training_params.lr_warmup_epochs
-            <= context.epoch
-            < self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
-        )
+        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
+        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs


 class FunctionLRCallback(LRCallbackBase):
@@ -538,11 +529,8 @@ class FunctionLRCallback(LRCallbackBase):
         self.max_epochs = max_epochs

     def is_lr_scheduling_enabled(self, context):
-        return (
-            self.training_params.lr_warmup_epochs
-            <= context.epoch
-            < self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
-        )
+        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
+        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs

     def perform_scheduling(self, context):
         effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
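The four LR callbacks above now share a two-line version of the same predicate: scheduling is active after warmup and before cooldown. Naming the upper bound both fits the chained comparison on one line and documents its meaning. A self-contained sketch (the dataclass is hypothetical scaffolding; the attribute names follow the diff):

```python
from dataclasses import dataclass


@dataclass
class TrainingParams:  # hypothetical stand-in for the real training_params object
    lr_warmup_epochs: int
    lr_cooldown_epochs: int
    max_epochs: int


def is_lr_scheduling_enabled(params: TrainingParams, epoch: int) -> bool:
    # Active strictly between warmup and cooldown, as in the refactored callbacks.
    post_warmup_epochs = params.max_epochs - params.lr_cooldown_epochs
    return params.lr_warmup_epochs <= epoch < post_warmup_epochs


params = TrainingParams(lr_warmup_epochs=5, lr_cooldown_epochs=10, max_epochs=100)
print(is_lr_scheduling_enabled(params, epoch=50))  # True
print(is_lr_scheduling_enabled(params, epoch=95))  # False: inside the cooldown window
```

Since the predicate is now identical in all four callbacks, hoisting it into `LRCallbackBase` would be a natural follow-up.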
src/super_gradients/training/utils/checkpoint_utils.py
@@ -10,7 +10,8 @@ except (ModuleNotFoundError, ImportError, NameError):
     from torch.hub import _download_url_to_file as download_url_to_file


-def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, model_checkpoints_location: str, external_checkpoint_path: str, overwrite_local_checkpoint: bool, load_weights_only: bool):
+def get_ckpt_local_path(source_ckpt_folder_name: str, experiment_name: str, ckpt_name: str, model_checkpoints_location: str, external_checkpoint_path: str,
+                        overwrite_local_checkpoint: bool, load_weights_only: bool):
     """
     Gets the local path to the checkpoint file, which will be:
         - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.
@@ -132,7 +133,8 @@ def read_ckpt_state_dict(ckpt_path: str, device="cpu"):
     return state_dict


-def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict, exclude: list = [], solver: callable=None):
+def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict,
+                                              exclude: list = [], solver: callable = None):
     """
     Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit
     the ckpt in order to properly load the weights into the model. If unsuccessful - returns None
@@ -174,7 +176,8 @@ def raise_informative_runtime_error(state_dict, checkpoint, exception_msg):
         raise RuntimeError(exception_msg)


-def load_checkpoint_to_model(ckpt_local_path: str, load_backbone: bool, net: torch.nn.Module, strict: str, load_weights_only: bool, load_ema_as_net: bool = False):
+def load_checkpoint_to_model(ckpt_local_path: str, load_backbone: bool, net: torch.nn.Module, strict: str,
+                             load_weights_only: bool, load_ema_as_net: bool = False):
     """
     Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.

@@ -226,12 +229,14 @@ class MissingPretrainedWeightsException(Exception):
         self.message = "Missing pretrained wights: " + desc
         super().__init__(self.message)

+
 def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):
     """
     Helper method for reshaping old pretrained checkpoint's focus weights to 6x6 conv weights.
     """

-    if ckpt_val.shape != model_val.shape and ckpt_key == 'module._backbone._modules_list.0.conv.conv.weight' and model_key == '_backbone._modules_list.0.conv.weight':
+    if ckpt_val.shape != model_val.shape and ckpt_key == 'module._backbone._modules_list.0.conv.conv.weight' and \
+            model_key == '_backbone._modules_list.0.conv.weight':
         model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
         model_val.data[:, :, 1::2, ::2] = ckpt_val.data[:, 3:6]
         model_val.data[:, :, ::2, 1::2] = ckpt_val.data[:, 6:9]
@@ -242,6 +247,7 @@ def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):

     return replacement

+
 def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):

     """
@@ -262,5 +268,7 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
     if 'ema_net' in pretrained_state_dict.keys():
         pretrained_state_dict['net'] = pretrained_state_dict['ema_net']
     solver = _yolox_ckpt_solver if "yolox" in architecture else None
-    adapted_pretrained_state_dict = adapt_state_dict_to_fit_model_layer_names(model_state_dict=model.state_dict(), source_ckpt=pretrained_state_dict, solver=solver)
+    adapted_pretrained_state_dict = adapt_state_dict_to_fit_model_layer_names(model_state_dict=model.state_dict(),
+                                                                              source_ckpt=pretrained_state_dict,
+                                                                              solver=solver)
     model.load_state_dict(adapted_pretrained_state_dict['net'], strict=False)
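Several of the additions in this hunk are bare blank lines: flake8's E302 requires exactly two blank lines before a top-level `def` or `class`. A minimal sketch:

```python
def first():
    return 1


def second():  # E302: two blank lines must separate top-level definitions
    return 2
```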
src/super_gradients/training/utils/detection_utils.py
@@ -910,7 +910,8 @@ class DetectionVisualization:
             targets_cur = targets[targets[:, 0] == i]

             image_name = '_'.join([str(batch_name), str(i)])
-            res_image = DetectionVisualization._visualize_image(image_np[i], preds, targets_cur, class_names, box_thickness, gt_alpha, image_scale, checkpoint_dir, image_name)
+            res_image = DetectionVisualization._visualize_image(image_np[i], preds, targets_cur, class_names, box_thickness, gt_alpha, image_scale,
+                                                                checkpoint_dir, image_name)
             if res_image is not None:
                 out_images.append(res_image)

@@ -1085,6 +1086,7 @@ class CrowdDetectionCollateFN(DetectionCollateFN):
         ims, targets, crowd_targets = batch[0:3]
         return ims, self._format_targets(targets), {"crowd_targets": self._format_targets(crowd_targets)}

+
 def compute_box_area(box: torch.Tensor) -> torch.Tensor:
     """Compute the area of one or many boxes.
          :param box: One or many boxes, shape = (4, ?), each box in format (x1, y1, x2, y2)
src/super_gradients/training/utils/export_utils.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F


 def fuse_conv_bn(model: nn.Module, replace_bn_with_identity: bool = False):
src/super_gradients/training/utils/get_model_stats.py
@@ -136,7 +136,8 @@ def _convert_summary_dict_to_string(summary: dict, high_verbosity: bool, input_d
     trainable_params = 0
     if high_verbosity:
         summary_str += f"{'-' * 200}\n"
-        line_new = f'{"block (type)":>20} {"Layer (type)":>20} {"Output Shape":>63} {"Param #":>15} {"inference time[ms]":>25} {"gpu_cached_memory[GB]":>25} {"gpu_occupation[GB]":>25}'
+        line_new = f'{"block (type)":>20} {"Layer (type)":>20} {"Output Shape":>63} {"Param #":>15} ' \
+                   f'{"inference time[ms]":>25} {"gpu_cached_memory[GB]":>25} {"gpu_occupation[GB]":>25}'
         summary_str += f"{line_new}\n"
         summary_str += f"{'=' * 200}\n"
     for layer in summary:
@@ -178,7 +179,7 @@ def _convert_summary_dict_to_string(summary: dict, high_verbosity: bool, input_d
                    f"Params size (MB): {total_params_size}\n" \
                    f"Estimated Total Size (MB): {total_size}\n"

-    summary_str += str(["Memory Footprint (percentage): %0.2f" % gpu_memory_utilization[i] for i in range(4)]) + "\n" \
-                                                                                                                 f"{'-' * 200}\n" if device == 'cuda' else f"{'-' * 200}\n"
+    summary_str += str(["Memory Footprint (percentage): %0.2f" % gpu_memory_utilization[i] for i in range(4)]) + "\n"
+    summary_str += f"{'-' * 200}\n" if device == 'cuda' else f"{'-' * 200}\n"

     return summary_str
src/super_gradients/training/utils/quantization_utils.py
@@ -83,12 +83,13 @@ def calibrate_model(model: torch.nn.Module, calib_data_loader: torch.utils.data.
     """
     Calibrates torch model with quantized modules.

-    :param model: torch.nn.Module, model to perfrom the calibration on.
-    :param calib_data_loader: torch.utils.data.DataLoader, data loader of the calibration dataset.
-    :param method: str, One of [percentile, mse, entropy, max]. Statistics method for amax
-                 computation of the quantized modules (default=percentile).
-    :param num_calib_batches: int, number of batches to collect the statistics from.
-    :param percentile: float, percentile value to use when SgModel,quant_modules_calib_method='percentile'. Discarded when other methods are used (Default=99.99).
+    :param model:               torch.nn.Module, model to perform the calibration on.
+    :param calib_data_loader:   torch.utils.data.DataLoader, data loader of the calibration dataset.
+    :param method:              str, One of [percentile, mse, entropy, max]. Statistics method for amax computation of the quantized modules
+                                (Default=percentile).
+    :param num_calib_batches:   int, number of batches to collect the statistics from.
+    :param percentile:          float, percentile value to use when SgModel,quant_modules_calib_method='percentile'. Discarded when other methods are used
+                                (Default=99.99).

     """
     if _imported_pytorch_quantization_failure is not None:
src/super_gradients/training/utils/ssd_utils.py
@@ -4,7 +4,6 @@ from typing import List

 import numpy as np
 import torch
-from torch.nn import functional as F

 from super_gradients.training.utils.detection_utils import non_max_suppression, NMS_Type, \
     matrix_non_max_suppression, DetectionPostPredictionCallback