Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#835 Feature/sg 772 predict using validation

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-772_predict_using_validation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
  1. from typing import Union, Mapping
  2. from super_gradients.common.factories.base_factory import BaseFactory
  3. from super_gradients.common.factories.list_factory import ListFactory
  4. from super_gradients.common.registry.registry import PROCESSINGS
  class ProcessingFactory(BaseFactory):
      """
      Factory that builds processing objects out of the PROCESSINGS registry.

      A config containing the "ComposeProcessing" key gets special treatment: the entries of its
      "processings" list are resolved through a ListFactory of ProcessingFactory before the
      composed config is handed to BaseFactory.get.
      """

      def __init__(self):
          super().__init__(PROCESSINGS)

      def get(self, conf: Union[str, dict]):
          """
          Resolve conf through the processing registry.

          :param conf: A processing name, or a mapping from a processing name to its parameters.
          :return:     The result of BaseFactory.get for the (possibly expanded) config.
          """
          if isinstance(conf, Mapping) and "ComposeProcessing" in conf:
              # Resolve on shallow copies: the caller's config (e.g. a dict read from a checkpoint)
              # must keep its plain, serializable description instead of being overwritten in place.
              compose_params = dict(conf["ComposeProcessing"])
              compose_params["processings"] = ListFactory(ProcessingFactory()).get(compose_params["processings"])
              conf = {**conf, "ComposeProcessing": compose_params}
          return super().get(conf)
Discard
@@ -398,3 +398,15 @@ class Datasets:
     SUPERVISELY_PERSONS_DATASET = "SuperviselyPersonsDataset"
     PASCAL_VOC_AND_AUG_UNIFIED_DATASET = "PascalVOCAndAUGUnifiedDataset"
     COCO_KEY_POINTS_DATASET = "COCOKeypointsDataset"
+
+
+class Processings:
+    """Names of the prediction processing classes; each constant's value is identical to its attribute name."""
+
+    StandardizeImage = "StandardizeImage"
+    DetectionCenterPadding = "DetectionCenterPadding"
+    DetectionLongestMaxSizeRescale = "DetectionLongestMaxSizeRescale"
+    DetectionBottomRightPadding = "DetectionBottomRightPadding"
+    ImagePermute = "ImagePermute"
+    DetectionRescale = "DetectionRescale"
+    ReverseImageChannels = "ReverseImageChannels"
+    NormalizeImage = "NormalizeImage"
+    ComposeProcessing = "ComposeProcessing"
Discard
@@ -145,3 +145,6 @@ OPTIMIZERS = {
     Optimizers.RMS_PROP: optim.RMSprop,
 }
 register_optimizer = create_register_decorator(registry=OPTIMIZERS)
+
+PROCESSINGS = {}
+register_processing = create_register_decorator(registry=PROCESSINGS)
Discard
1
2
3
  1. from .module_interfaces import HasPredict, HasPreprocessingParams
  2. __all__ = ["HasPredict", "HasPreprocessingParams"]
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
  1. from typing_extensions import Protocol, runtime_checkable
  @runtime_checkable
  class HasPreprocessingParams(Protocol):
      """
      Structural (duck-typed) interface for torch datasets that can report their preprocessing params.

      Any object exposing a get_dataset_preprocessing_params method satisfies this protocol, so an
      isinstance check explicitly tells whether a dataset implements it.
      """

      def get_dataset_preprocessing_params(self):
          ...
  @runtime_checkable
  class HasPredict(Protocol):
      """
      Structural (duck-typed) interface marking torch models that offer the SG ".predict" functionality.

      An object satisfies it when it exposes set_dataset_processing_params, predict and predict_webcam.
      """

      def set_dataset_processing_params(self, *args, **kwargs):
          """Set the processing parameters for the dataset."""
          ...

      def predict(self, images, *args, **kwargs):
          ...

      def predict_webcam(self, *args, **kwargs):
          ...
Discard
@@ -69,7 +69,7 @@ val_dataset_params:
         pad_value: 114
     - DetectionStandardizeImage:
         max_value: 255.
-    - DetectionImagePermute:
+    - DetectionImagePermute
     - DetectionTargetsFormatTransform:
         max_targets: 50
         input_dim: [640, 640]
Discard
@@ -13,7 +13,7 @@ import numpy as np
 from tqdm import tqdm
 from torch.utils.data import Dataset
 
-from super_gradients.common.object_names import Datasets
+from super_gradients.common.object_names import Datasets, Processings
 from super_gradients.common.registry.registry import register_dataset
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.training.utils.detection_utils import get_cls_posx_in_target
@@ -473,3 +473,22 @@ class DetectionDataset(Dataset):
             plot_counter += 1
             if plot_counter == n_plots:
                 return
+
+    def get_dataset_preprocessing_params(self):
+        """
+        Collect the hardcoded preprocessing of this dataset, plus the adaptation for PIL.Image image reading (RGB).
+
+        The image processor is described as a config (names / dicts) that is meant to be resolved by the
+        processing factory: a channel reversal first, then a longest-max-size rescale when input_dim is set,
+        then whatever each transform reports through get_equivalent_preprocessing.
+
+        :return: dict with the keys "class_names", "image_processor", "iou" and "conf".
+        """
+        processings = [Processings.ReverseImageChannels]
+        if self.input_dim is not None:
+            processings.append({Processings.DetectionLongestMaxSizeRescale: {"output_shape": self.input_dim}})
+        for transform in self.transforms:
+            processings.extend(transform.get_equivalent_preprocessing())
+        return {
+            "class_names": self.classes,
+            "image_processor": {Processings.ComposeProcessing: {"processings": processings}},
+            "iou": 0.65,
+            "conf": 0.5,
+        }
Discard
@@ -3,7 +3,7 @@ from typing import Tuple, List, Mapping, Any, Dict
 
 import numpy as np
 import torch
-from torch.utils.data import default_collate, Dataset
+from torch.utils.data.dataloader import default_collate, Dataset
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.registry.registry import register_collate_function
Discard
@@ -10,12 +10,14 @@ from typing import Union, Optional, List
 from torch import nn
 from omegaconf import DictConfig
 
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories.processing_factory import ProcessingFactory
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.models.sg_module import SgModule
 import super_gradients.common.factories.detection_modules_factory as det_factory
 from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
 from super_gradients.training.pipelines.pipelines import DetectionPipeline
-from super_gradients.training.transforms.processing import Processing
+from super_gradients.training.processing.processing import Processing
 from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.utils.media.image import ImageSource
 
@@ -110,6 +112,7 @@ class CustomizableDetector(SgModule):
     def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
         raise NotImplementedError
 
+    @resolve_param("image_processor", ProcessingFactory())
     def set_dataset_processing_params(
         self,
         class_names: Optional[List[str]] = None,
Discard
@@ -1,7 +1,8 @@
 from typing import Union, Optional, List
 
 from torch import Tensor
-
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories.processing_factory import ProcessingFactory
 from super_gradients.common.registry.registry import register_model
 from super_gradients.common.object_names import Models
 from super_gradients.modules import RepVGGBlock
@@ -14,7 +15,7 @@ from super_gradients.training.models.arch_params_factory import get_arch_params
 from super_gradients.training.models.detection_models.pp_yolo_e.post_prediction_callback import PPYoloEPostPredictionCallback, DetectionPostPredictionCallback
 from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
 from super_gradients.training.pipelines.pipelines import DetectionPipeline
-from super_gradients.training.transforms.processing import Processing
+from super_gradients.training.processing.processing import Processing
 from super_gradients.training.utils.media.image import ImageSource
 
 
@@ -37,6 +38,7 @@ class PPYoloE(SgModule):
     def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
         return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)
 
+    @resolve_param("image_processor", ProcessingFactory())
     def set_dataset_processing_params(
         self,
         class_names: Optional[List[str]] = None,
Discard
@@ -4,6 +4,8 @@ from typing import Union, Type, List, Tuple, Optional
 import torch
 import torch.nn as nn
 
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories.processing_factory import ProcessingFactory
 from super_gradients.modules import CrossModelSkipConnection
 from super_gradients.training.models.classification_models.regnet import AnyNetX, Stage
 from super_gradients.training.models.detection_models.csp_darknet53 import Conv, GroupedConvBlock, CSPDarknet53, get_yolo_type_params, SPP
@@ -13,7 +15,7 @@ from super_gradients.training.utils.detection_utils import non_max_suppression,
 from super_gradients.training.utils.utils import HpmStruct, check_img_size_divisibility, get_param
 from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
 from super_gradients.training.pipelines.pipelines import DetectionPipeline
-from super_gradients.training.transforms.processing import Processing
+from super_gradients.training.processing.processing import Processing
 from super_gradients.training.utils.media.image import ImageSource
 
 COCO_DETECTION_80_CLASSES_BBOX_ANCHORS = Anchors(
@@ -425,6 +427,7 @@ class YoloBase(SgModule):
     def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
         return YoloPostPredictionCallback(conf=conf, iou=iou)
 
+    @resolve_param("image_processor", ProcessingFactory())
     def set_dataset_processing_params(
         self,
         class_names: Optional[List[str]] = None,
Discard
@@ -6,6 +6,7 @@ import torch
 
 from super_gradients.common.data_types.enum.strict_load import StrictLoad
 from super_gradients.common.plugins.deci_client import DeciClient, client_enabled
+from super_gradients.module_interfaces import HasPredict
 from super_gradients.training import utils as core_utils
 from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
 from super_gradients.training.models import SgModule
@@ -20,7 +21,7 @@ from super_gradients.training.utils.checkpoint_utils import (
 )
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names
-from super_gradients.training.transforms.processing import get_pretrained_processing_params
+from super_gradients.training.processing.processing import get_pretrained_processing_params
 
 logger = get_logger(__name__)
 
@@ -136,9 +137,10 @@ def instantiate_model(
                 net.replace_head(new_num_classes=num_classes_new_head)
                 arch_params.num_classes = num_classes_new_head
 
-            # TODO: remove once we load it from the checkpoint
-            processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
-            net.set_dataset_processing_params(**processing_params)
+            # STILL NEED TO GET PREPROCESSING PARAMS IN CASE CHECKPOINT HAS NO RECIPE
+            if isinstance(net, HasPredict):
+                processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
+                net.set_dataset_processing_params(**processing_params)
 
     _add_model_name_attribute(net, model_name)
 
@@ -200,7 +202,9 @@ def get(
         raise ValueError("Please set checkpoint_path when load_backbone=True")
 
     if checkpoint_path:
-        load_ema_as_net = "ema_net" in read_ckpt_state_dict(ckpt_path=checkpoint_path).keys()
+        ckpt_entries = read_ckpt_state_dict(ckpt_path=checkpoint_path).keys()
+        load_processing = "processing_params" in ckpt_entries
+        load_ema_as_net = "ema_net" in ckpt_entries
         _ = load_checkpoint_to_model(
             ckpt_local_path=checkpoint_path,
             load_backbone=load_backbone,
@@ -208,6 +212,7 @@ def get(
             strict=strict_load.value if hasattr(strict_load, "value") else strict_load,
             load_weights_only=True,
             load_ema_as_net=load_ema_as_net,
+            load_processing_params=load_processing,
         )
     if checkpoint_num_classes != num_classes:
         net.replace_head(new_num_classes=num_classes)
Discard
@@ -3,7 +3,6 @@ from typing import Union
 from torch import nn
 
 from super_gradients.training.utils.utils import HpmStruct
-from super_gradients.training.models.prediction_results import ImagesPredictions
 
 
 class SgModule(nn.Module):
@@ -63,13 +62,3 @@ class SgModule(nn.Module):
         """
 
         raise NotImplementedError
-
-    def predict(self, images, *args, **kwargs) -> ImagesPredictions:
-        raise NotImplementedError(f"`predict` is not implemented for {self.__class__.__name__}.")
-
-    def predict_webcam(self, *args, **kwargs) -> None:
-        raise NotImplementedError(f"`predict_webcam` is not implemented for {self.__class__.__name__}.")
-
-    def set_dataset_processing_params(self, *args, **kwargs) -> None:
-        """Set the processing parameters for the dataset."""
-        pass
Discard
@@ -20,7 +20,7 @@ from super_gradients.training.models.prediction_results import (
     VideoPredictions,
 )
 from super_gradients.training.models.predictions import Prediction, DetectionPrediction
-from super_gradients.training.transforms.processing import Processing, ComposeProcessing
+from super_gradients.training.processing.processing import Processing, ComposeProcessing
 from super_gradients.common.abstractions.abstract_logger import get_logger
 
 
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
  1. from .processing import (
  2. StandardizeImage,
  3. DetectionRescale,
  4. DetectionLongestMaxSizeRescale,
  5. DetectionBottomRightPadding,
  6. DetectionCenterPadding,
  7. ImagePermute,
  8. ReverseImageChannels,
  9. NormalizeImage,
  10. ComposeProcessing,
  11. )
  12. __all__ = [
  13. "StandardizeImage",
  14. "DetectionRescale",
  15. "DetectionLongestMaxSizeRescale",
  16. "DetectionBottomRightPadding",
  17. "DetectionCenterPadding",
  18. "ImagePermute",
  19. "ReverseImageChannels",
  20. "NormalizeImage",
  21. "ComposeProcessing",
  22. ]
Discard
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 
 import numpy as np
 
+from super_gradients.common.registry.registry import register_processing
 from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
 from super_gradients.training.models.predictions import Prediction, DetectionPrediction
 from super_gradients.training.transforms.utils import (
@@ -15,6 +16,7 @@ from super_gradients.training.transforms.utils import (
     _shift_bboxes,
     PaddingCoordinates,
 )
+from super_gradients.common.object_names import Processings
 
 
 @dataclass
@@ -58,6 +60,7 @@ class Processing(ABC):
         pass
 
 
+@register_processing(Processings.ComposeProcessing)
 class ComposeProcessing(Processing):
     """Compose a list of Processing objects into a single Processing object."""
 
@@ -80,6 +83,7 @@ class ComposeProcessing(Processing):
         return postprocessed_predictions
 
 
+@register_processing(Processings.ImagePermute)
 class ImagePermute(Processing):
     """Permute the image dimensions.
 
@@ -97,6 +101,7 @@ class ImagePermute(Processing):
         return predictions
 
 
+@register_processing(Processings.ReverseImageChannels)
 class ReverseImageChannels(Processing):
     """Reverse the order of the image channels (RGB -> BGR or BGR -> RGB)."""
 
@@ -117,6 +122,7 @@ class ReverseImageChannels(Processing):
         return predictions
 
 
+@register_processing(Processings.StandardizeImage)
 class StandardizeImage(Processing):
     """Standardize image pixel values with img/max_val
 
@@ -140,6 +146,7 @@ class StandardizeImage(Processing):
         return predictions
 
 
+@register_processing(Processings.NormalizeImage)
 class NormalizeImage(Processing):
     """Normalize an image based on means and standard deviation.
 
@@ -189,11 +196,13 @@ class _DetectionPadding(Processing, ABC):
         pass
 
 
+@register_processing(Processings.DetectionCenterPadding)
 class DetectionCenterPadding(_DetectionPadding):
     def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates:
         return _get_center_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape)
 
 
+@register_processing(Processings.DetectionBottomRightPadding)
 class DetectionBottomRightPadding(_DetectionPadding):
     def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates:
         return _get_bottom_right_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape)
@@ -236,12 +245,14 @@ class _LongestMaxSizeRescale(Processing, ABC):
         return image, RescaleMetadata(original_shape=(height, width), scale_factor_h=scale_factor, scale_factor_w=scale_factor)
 
 
+@register_processing(Processings.DetectionRescale)
 class DetectionRescale(_Rescale):
     def postprocess_predictions(self, predictions: DetectionPrediction, metadata: RescaleMetadata) -> DetectionPrediction:
         predictions.bboxes_xyxy = _rescale_bboxes(targets=predictions.bboxes_xyxy, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w))
         return predictions
 
 
+@register_processing(Processings.DetectionLongestMaxSizeRescale)
 class DetectionLongestMaxSizeRescale(_LongestMaxSizeRescale):
     def postprocess_predictions(self, predictions: DetectionPrediction, metadata: RescaleMetadata) -> DetectionPrediction:
         predictions.bboxes_xyxy = _rescale_bboxes(targets=predictions.bboxes_xyxy, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w))
Discard
@@ -18,6 +18,7 @@ from torchmetrics import MetricCollection
 from tqdm import tqdm
 
 from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path, get_ckpt_local_path
+from super_gradients.module_interfaces import HasPreprocessingParams, HasPredict
 
 from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names
 from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -598,6 +599,10 @@ class Trainer:
 
         if self.ema:
             state["ema_net"] = self.ema_model.ema.state_dict()
+
+        if isinstance(self.net.module, HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
+            state["processing_params"] = self.valid_loader.dataset.get_dataset_preprocessing_params()
+
         # SAVES CURRENT MODEL AS ckpt_latest
         self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
 
@@ -1217,6 +1222,7 @@ class Trainer:
             train_dataloader_len=len(self.train_loader),
         )
 
+        self._set_net_preprocessing_from_valid_loader()
         try:
             # HEADERS OF THE TRAINING PROGRESS
             if not silent_mode:
@@ -1314,6 +1320,16 @@ class Trainer:
             if not self.ddp_silent_mode:
                 self.sg_logger.close()
 
+    def _set_net_preprocessing_from_valid_loader(self):
+        """
+        Best-effort transfer of the validation dataset's preprocessing params to the network.
+
+        Only acts when the wrapped network satisfies HasPredict and the validation dataset satisfies
+        HasPreprocessingParams. Any failure is logged as a warning rather than raised.
+        """
+        if isinstance(self.net.module, HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
+            try:
+                self.net.module.set_dataset_processing_params(**self.valid_loader.dataset.get_dataset_preprocessing_params())
+            except Exception as e:
+                # The trailing space matters: the two literals are concatenated, and without it
+                # the message read "Before callingpredict".
+                logger.warning(
+                    f"Could not set preprocessing pipeline from the validation dataset:\n {e}.\n Before calling "
+                    "predict make sure to call set_dataset_processing_params."
+                )
+
     def _reset_best_metric(self):
         self.best_metric = -1 * np.inf if self.greater_metric_to_watch_is_better else np.inf
 
Discard
@@ -11,7 +11,7 @@ from PIL import Image, ImageFilter, ImageOps
 from torchvision import transforms as transforms
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common.object_names import Transforms
+from super_gradients.common.object_names import Transforms, Processings
 from super_gradients.common.registry.registry import register_transform
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.data_formats_factory import ConcatenatedTensorFormatFactory
@@ -424,6 +424,9 @@ class DetectionTransform:
     def __repr__(self):
         return self.__class__.__name__ + str(self.__dict__).replace("{", "(").replace("}", ")")
 
+    def get_equivalent_preprocessing(self) -> List:
+        """
+        Describe this transform as a list of processing configs.
+
+        This base implementation has no equivalent to offer; subclasses override it.
+
+        :raises NotImplementedError: always.
+        """
+        raise NotImplementedError()
+
 
 @register_transform(Transforms.DetectionStandardize)
 class DetectionStandardize(DetectionTransform):
@@ -441,6 +444,9 @@ class DetectionStandardize(DetectionTransform):
         sample["image"] = (sample["image"] / self.max_value).astype(np.float32)
         return sample
 
+    def get_equivalent_preprocessing(self) -> List[Dict]:
+        """Describe this transform as one StandardizeImage entry carrying the same max_value."""
+        standardize_kwargs = {"max_value": self.max_value}
+        return [{Processings.StandardizeImage: standardize_kwargs}]
+
 
 @register_transform(Transforms.DetectionMosaic)
 class DetectionMosaic(DetectionTransform):
@@ -717,6 +723,9 @@ class DetectionImagePermute(DetectionTransform):
         sample["image"] = np.ascontiguousarray(sample["image"].transpose(*self.dims))
         return sample
 
+    def get_equivalent_preprocessing(self) -> List[Dict]:
+        """Describe this transform as one ImagePermute entry whose permutation is this transform's dims."""
+        permute_kwargs = {"permutation": self.dims}
+        return [{Processings.ImagePermute: permute_kwargs}]
+
 
 @register_transform(Transforms.DetectionPadToSize)
 class DetectionPadToSize(DetectionTransform):
@@ -748,6 +757,9 @@ class DetectionPadToSize(DetectionTransform):
             sample["crowd_target"] = _shift_bboxes(targets=crowd_targets, shift_w=padding_coordinates.left, shift_h=padding_coordinates.top)
         return sample
 
+    def get_equivalent_preprocessing(self) -> List:
+        """Describe this transform as one DetectionCenterPadding entry with the same output size and pad value."""
+        padding_kwargs = {"output_shape": self.output_size, "pad_value": self.pad_value}
+        return [{Processings.DetectionCenterPadding: padding_kwargs}]
+
 
 @register_transform(Transforms.DetectionPaddedRescale)
 class DetectionPaddedRescale(DetectionTransform):
@@ -779,6 +791,13 @@ class DetectionPaddedRescale(DetectionTransform):
             sample["crowd_target"] = _rescale_xyxy_bboxes(crowd_targets, r)
         return sample
 
+    def get_equivalent_preprocessing(self) -> List[Dict]:
+        """
+        Describe this transform as three processing entries, in order: longest-max-size rescale to
+        input_dim, bottom-right padding to input_dim with pad_value, and an image permute using swap.
+        """
+        rescale = {Processings.DetectionLongestMaxSizeRescale: {"output_shape": self.input_dim}}
+        padding = {Processings.DetectionBottomRightPadding: {"output_shape": self.input_dim, "pad_value": self.pad_value}}
+        permute = {Processings.ImagePermute: {"permutation": self.swap}}
+        return [rescale, padding, permute]
+
 
 @register_transform(Transforms.DetectionHorizontalFlip)
 class DetectionHorizontalFlip(DetectionTransform):
@@ -830,6 +849,9 @@ class DetectionRescale(DetectionTransform):
             sample["crowd_target"] = _rescale_bboxes(crowd_targets, scale_factors=(sy, sx))
         return sample
 
+    def get_equivalent_preprocessing(self) -> List[Dict]:
+        """Describe this transform as one DetectionRescale entry targeting the same output size."""
+        rescale_kwargs = {"output_shape": self.output_size}
+        return [{Processings.DetectionRescale: rescale_kwargs}]
+
 
 @register_transform(Transforms.DetectionRandomRotate90)
 class DetectionRandomRotate90(DetectionTransform):
@@ -907,6 +929,11 @@ class DetectionRGB2BGR(DetectionTransform):
             sample["image"] = sample["image"][..., ::-1]
         return sample
 
+    def get_equivalent_preprocessing(self) -> List:
+        """
+        Describe this transform as a single ReverseImageChannels processing entry.
+
+        :return: One-element list holding the processing name.
+        :raises RuntimeError: if prob < 1, since a randomly applied transform has no deterministic equivalent.
+        """
+        if self.prob < 1:
+            raise RuntimeError("Cannot set preprocessing pipeline with randomness. Set prob to 1.")
+        # A parameterless processing is referenced by its bare name. The previous `{name}` was a
+        # set literal, which is neither a name nor a name-to-params mapping.
+        return [Processings.ReverseImageChannels]
+
 
 @register_transform(Transforms.DetectionHSV)
 class DetectionHSV(DetectionTransform):
@@ -961,6 +988,9 @@ class DetectionNormalize(DetectionTransform):
         sample["image"] = (sample["image"] - self.mean) / self.std
         return sample
 
+    def get_equivalent_preprocessing(self) -> List[Dict]:
+        """Describe this transform as one NormalizeImage entry carrying the same mean and std."""
+        normalize_kwargs = {"mean": self.mean, "std": self.std}
+        return [{Processings.NormalizeImage: normalize_kwargs}]
+
 
 @register_transform(Transforms.DetectionTargetsFormatTransform)
 class DetectionTargetsFormatTransform(DetectionTransform):
@@ -1045,6 +1075,9 @@ class DetectionTargetsFormatTransform(DetectionTransform):
         padded_targets = np.ascontiguousarray(padded_targets, dtype=np.float32)
         return padded_targets
 
+    def get_equivalent_preprocessing(self) -> List:
+        """Report no equivalent processing entries for this transform (always an empty list)."""
+        no_processings = []
+        return no_processings
+
 
 def get_aug_params(value: Union[tuple, float], center: float = 0) -> float:
     """
Discard
@@ -7,6 +7,7 @@ import torch
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
 from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
+from super_gradients.module_interfaces import HasPredict
 from super_gradients.training.pretrained_models import MODEL_URLS
 from super_gradients.common.data_types import StrictLoad
 
@@ -186,16 +187,20 @@ def load_checkpoint_to_model(
     strict: Union[str, StrictLoad] = StrictLoad.NO_KEY_MATCHING,
     load_weights_only: bool = False,
     load_ema_as_net: bool = False,
+    load_processing_params: bool = False,
 ):
     """
     Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.
 
+
     :param load_ema_as_net: Will load the EMA inside the checkpoint file to the network when set
     :param ckpt_local_path: local path to the checkpoint file
     :param load_backbone: whether to load the checkpoint as a backbone
     :param net: network to load the checkpoint to
     :param strict:
-    :param load_weights_only:
+    :param load_weights_only: Whether to ignore all other entries other then "net".
+    :param load_processing_params: Whether to call set_dataset_processing_params on "processing_params" entry inside the
+     checkpoint file (default=False).
     :return:
     """
     if isinstance(strict, str):
@@ -227,6 +232,17 @@ def load_checkpoint_to_model(
     message_model = "model" if not load_backbone else "model's backbone"
     logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
 
+    if (isinstance(net, HasPredict) or (hasattr(net, "module") and isinstance(net.module, HasPredict))) and load_processing_params:
+        if "processing_params" not in checkpoint.keys():
+            raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
+        try:
+            net.set_dataset_processing_params(**checkpoint["processing_params"])
+        except Exception as e:
+            logger.warning(
+                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
+                "predict make sure to call set_dataset_processing_params."
+            )
+
     if load_weights_only or load_backbone:
         # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
         [checkpoint.pop(key) for key in list(checkpoint.keys()) if key != "net"]
Discard
@@ -30,6 +30,7 @@ from tests.unit_tests.load_checkpoint_test import LoadCheckpointTest
 from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest
 from tests.unit_tests.max_batches_loop_break_test import MaxBatchesLoopBreakTest
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
+from tests.unit_tests.preprocessing_unit_test import PreprocessingUnitTest
 from tests.unit_tests.quantization_utility_tests import QuantizationUtilityTest
 from tests.unit_tests.random_erase_test import RandomEraseTest
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
@@ -129,6 +130,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DEKRLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationMetrics))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LoadCheckpointTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PreprocessingUnitTest))
 
     def _add_modules_to_end_to_end_tests_suite(self):
         """
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
  1. import os
  2. import unittest
  3. from pathlib import Path
  4. from super_gradients import Trainer
  5. from super_gradients.training import models
  6. from super_gradients.training.datasets import COCODetectionDataset
  7. from super_gradients.training.metrics import DetectionMetrics
  8. from super_gradients.training.models import YoloPostPredictionCallback
  9. from super_gradients.training.processing import ReverseImageChannels, DetectionLongestMaxSizeRescale, DetectionBottomRightPadding, ImagePermute
  10. from super_gradients.training.utils.detection_utils import DetectionCollateFN
  11. from super_gradients.training import dataloaders
  12. class PreprocessingUnitTest(unittest.TestCase):
  def setUp(self) -> None:
      """Locate the "data/tinycoco" directory two levels above this test file."""
      tests_root = Path(__file__).parent.parent
      self.mini_coco_data_dir = str(tests_root / "data" / "tinycoco")
  def test_getting_preprocessing_params(self):
      """Check the preprocessing params a COCO detection dataset reports for a padded-rescale pipeline."""
      dataset = COCODetectionDataset(
          data_dir=self.mini_coco_data_dir,
          subdir="images/train2017",
          json_file="instances_train2017.json",
          cache=False,
          input_dim=[512, 512],
          transforms=[
              {"DetectionPaddedRescale": {"input_dim": [512, 512]}},
              {"DetectionTargetsFormatTransform": {"max_targets": 50, "input_dim": [512, 512], "output_format": "LABEL_CXCYWH"}},
          ],
      )
      params = dataset.get_dataset_preprocessing_params()

      # The rescale shows up twice: once from the dataset's own input_dim and once from DetectionPaddedRescale.
      expected_processings = [
          "ReverseImageChannels",
          {"DetectionLongestMaxSizeRescale": {"output_shape": [512, 512]}},
          {"DetectionLongestMaxSizeRescale": {"output_shape": [512, 512]}},
          {"DetectionBottomRightPadding": {"output_shape": [512, 512], "pad_value": 114}},
          {"ImagePermute": {"permutation": (2, 0, 1)}},
      ]
      expected_image_processor = {"ComposeProcessing": {"processings": expected_processings}}

      self.assertEqual(len(params["class_names"]), 80)
      self.assertEqual(params["image_processor"], expected_image_processor)
      self.assertEqual(params["iou"], 0.65)
      self.assertEqual(params["conf"], 0.5)
  def test_setting_preprocessing_params_from_validation_set(self):
      """Train for one epoch and check that the model picked up the validation dataset's preprocessing params."""
      train_dataset_params = {
          "data_dir": self.mini_coco_data_dir,
          "subdir": "images/train2017",
          "json_file": "instances_train2017.json",
          "cache": False,
          "input_dim": [329, 320],
          "transforms": [
              {"DetectionPaddedRescale": {"input_dim": [512, 512]}},
              {"DetectionTargetsFormatTransform": {"max_targets": 50, "input_dim": [512, 512], "output_format": "LABEL_CXCYWH"}},
          ],
      }
      val_dataset_params = {
          "data_dir": self.mini_coco_data_dir,
          "subdir": "images/val2017",
          "json_file": "instances_val2017.json",
          "cache": False,
          "input_dim": [329, 320],
          "transforms": [
              {"DetectionPaddedRescale": {"input_dim": [512, 512]}},
              {"DetectionTargetsFormatTransform": {"max_targets": 50, "input_dim": [512, 512], "output_format": "LABEL_CXCYWH"}},
          ],
      }
      trainset = COCODetectionDataset(**train_dataset_params)
      train_loader = dataloaders.get(dataset=trainset, dataloader_params={"collate_fn": DetectionCollateFN()})
      valset = COCODetectionDataset(**val_dataset_params)
      valid_loader = dataloaders.get(dataset=valset, dataloader_params={"collate_fn": DetectionCollateFN()})
      trainer = Trainer("test_setting_preprocessing_params_from_validation_set")
      detection_train_params_yolox = {
          "max_epochs": 1,
          "lr_mode": "cosine",
          "cosine_final_lr_ratio": 0.05,
          "warmup_bias_lr": 0.0,
          "warmup_momentum": 0.9,
          "initial_lr": 0.02,
          "loss": "yolox_loss",
          "criterion_params": {"strides": [8, 16, 32], "num_classes": 80},  # output strides of all yolo outputs
          "train_metrics_list": [],
          "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True, num_cls=5)],
          "metric_to_watch": "mAP@0.50:0.95",
          "greater_metric_to_watch_is_better": True,
          "average_best_models": False,
      }
      model = models.get("yolox_s", num_classes=80)
      trainer.train(model=model, training_params=detection_train_params_yolox, train_loader=train_loader, valid_loader=valid_loader)

      processing_list = model._image_processor.processings
      expected_types = (
          ReverseImageChannels,
          DetectionLongestMaxSizeRescale,
          DetectionLongestMaxSizeRescale,
          DetectionBottomRightPadding,
          ImagePermute,
      )
      # assertEqual, not assertTrue: assertTrue(len(x), 5) treats 5 as the failure message
      # and passes for any non-empty list.
      self.assertEqual(len(processing_list), 5)
      for processing, expected_type in zip(processing_list, expected_types):
          self.assertIsInstance(processing, expected_type)
      self.assertEqual(model._default_nms_iou, 0.65)
      self.assertEqual(model._default_nms_conf, 0.5)
  def test_setting_preprocessing_params_from_checkpoint(self):
      """Check that preprocessing params are restored when a model is loaded from a checkpoint.

      A freshly built YoloX-S is asserted to have no image processor, NMS defaults or class
      names. It is then trained for one epoch on COCO train/val datasets built from
      ``self.mini_coco_data_dir``; a new model is loaded from the trainer's ``ckpt_best.pth``
      and must expose the expected five-step processing pipeline and NMS ``iou``/``conf`` values.
      """
      model = models.get("yolox_s", num_classes=80)
      self.assertIsNone(model._image_processor)
      self.assertIsNone(model._default_nms_iou)
      self.assertIsNone(model._default_nms_conf)
      self.assertIsNone(model._class_names)

      train_dataset_params = {
          "data_dir": self.mini_coco_data_dir,
          "subdir": "images/train2017",
          "json_file": "instances_train2017.json",
          "cache": False,
          "input_dim": [329, 320],
          "transforms": [
              {"DetectionPaddedRescale": {"input_dim": [512, 512]}},
              {"DetectionTargetsFormatTransform": {"max_targets": 50, "input_dim": [512, 512], "output_format": "LABEL_CXCYWH"}},
          ],
      }
      val_dataset_params = {
          "data_dir": self.mini_coco_data_dir,
          "subdir": "images/val2017",
          "json_file": "instances_val2017.json",
          "cache": False,
          "input_dim": [329, 320],
          "transforms": [
              {"DetectionPaddedRescale": {"input_dim": [512, 512]}},
              {"DetectionTargetsFormatTransform": {"max_targets": 50, "input_dim": [512, 512], "output_format": "LABEL_CXCYWH"}},
          ],
      }

      trainset = COCODetectionDataset(**train_dataset_params)
      train_loader = dataloaders.get(dataset=trainset, dataloader_params={"collate_fn": DetectionCollateFN()})
      valset = COCODetectionDataset(**val_dataset_params)
      valid_loader = dataloaders.get(dataset=valset, dataloader_params={"collate_fn": DetectionCollateFN()})

      trainer = Trainer("save_ckpt_for")
      detection_train_params_yolox = {
          "max_epochs": 1,
          "lr_mode": "cosine",
          "cosine_final_lr_ratio": 0.05,
          "warmup_bias_lr": 0.0,
          "warmup_momentum": 0.9,
          "initial_lr": 0.02,
          "loss": "yolox_loss",
          "criterion_params": {"strides": [8, 16, 32], "num_classes": 80},  # output strides of all yolo outputs
          "train_metrics_list": [],
          # NOTE(review): num_cls=5 differs from the model's num_classes=80 - confirm this is intentional.
          "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True, num_cls=5)],
          "metric_to_watch": "mAP@0.50:0.95",
          "greater_metric_to_watch_is_better": True,
          "average_best_models": False,
      }
      trainer.train(model=model, training_params=detection_train_params_yolox, train_loader=train_loader, valid_loader=valid_loader)

      # Reload from the saved checkpoint so the assertions below exercise the checkpoint path,
      # not the in-memory model that was just trained.
      model = models.get("yolox_s", num_classes=80, checkpoint_path=os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth"))

      processing_list = model._image_processor.processings
      # Check the length first: assertTrue(len(...), 5) would treat 5 as the failure message
      # and pass for any non-empty list, so the pipeline size was never actually verified.
      self.assertEqual(len(processing_list), 5)
      self.assertIsInstance(processing_list[0], ReverseImageChannels)
      self.assertIsInstance(processing_list[1], DetectionLongestMaxSizeRescale)
      self.assertIsInstance(processing_list[2], DetectionLongestMaxSizeRescale)
      self.assertIsInstance(processing_list[3], DetectionBottomRightPadding)
      self.assertIsInstance(processing_list[4], ImagePermute)
      self.assertEqual(model._default_nms_iou, 0.65)
      self.assertEqual(model._default_nms_conf, 0.5)

  # Run this module's unittest test cases when the file is executed directly as a script.
  if __name__ == "__main__":
      unittest.main()
Discard