Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#807 Feature/sg 747 add full pipeline with preprocessing

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-747-add_full_pipeline_with_preprocessing
1
2
3
4
5
6
7
8
9
10
11
12
13
  1. from super_gradients.common.object_names import Models
  2. from super_gradients.training import models
  # NOTE: `predict` is currently only supported for YoloX and PPYoloE models.
  model = models.get(Models.PP_YOLOE_S, pretrained_weights="coco")

  # Must stay a `list`: each entry is loaded on its own (URL, local path, numpy array, torch tensor or PIL image).
  IMAGES = [
      "https://miro.medium.com/v2/resize:fit:500/0*w1s81z-Q72obhE_z",
      "https://s.hs-data.com/bilder/spieler/gross/128069.jpg",
      "https://datasets-server.huggingface.co/assets/Chris1/cityscapes/--/Chris1--cityscapes/train/28/image/image.jpg",
  ]

  # `iou` and `conf` are forwarded to the model's post-prediction (NMS) callback.
  prediction = model.predict(IMAGES, iou=0.65, conf=0.5)

  # Draw the predicted boxes on every image and display them one by one.
  prediction.show()
Discard
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional, List
 
 
 from torch import Tensor
 from torch import Tensor
 
 
@@ -11,6 +11,10 @@ from super_gradients.training.models.detection_models.pp_yolo_e.pan import Custo
 from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_head import PPYOLOEHead
 from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_head import PPYOLOEHead
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.models.arch_params_factory import get_arch_params
 from super_gradients.training.models.arch_params_factory import get_arch_params
+from super_gradients.training.models.detection_models.pp_yolo_e.post_prediction_callback import PPYoloEPostPredictionCallback, DetectionPostPredictionCallback
+from super_gradients.training.models.results import DetectionResults
+from super_gradients.training.pipelines.pipelines import DetectionPipeline
+from super_gradients.training.transforms.processing import Processing
 
 
 
 
 class PPYoloE(SgModule):
 class PPYoloE(SgModule):
@@ -23,6 +27,37 @@ class PPYoloE(SgModule):
         self.neck = CustomCSPPAN(**arch_params["neck"], depth_mult=arch_params["depth_mult"], width_mult=arch_params["width_mult"])
         self.neck = CustomCSPPAN(**arch_params["neck"], depth_mult=arch_params["depth_mult"], width_mult=arch_params["width_mult"])
         self.head = PPYOLOEHead(**arch_params["head"], width_mult=arch_params["width_mult"], num_classes=arch_params["num_classes"])
         self.head = PPYOLOEHead(**arch_params["head"], width_mult=arch_params["width_mult"], num_classes=arch_params["num_classes"])
 
 
+        self._class_names: Optional[List[str]] = None
+        self._image_processor: Optional[Processing] = None
+
+    @staticmethod
+    def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
+        return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)
+
+    def set_dataset_processing_params(self, class_names: Optional[List[str]], image_processor: Optional[Processing]) -> None:
+        """Set the processing parameters for the dataset.
+
+        :param class_names:     (Optional) Names of the dataset the model was trained on.
+        :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training.
+        """
+        self._class_names = class_names or self._class_names
+        self._image_processor = image_processor or self._image_processor
+
+    def predict(self, images, iou: float = 0.65, conf: float = 0.01) -> DetectionResults:
+
+        if self._class_names is None or self._image_processor is None:
+            raise RuntimeError(
+                "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first."
+            )
+
+        pipeline = DetectionPipeline(
+            model=self,
+            image_processor=self._image_processor,
+            post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
+            class_names=self._class_names,
+        )
+        return pipeline(images)
+
     def forward(self, x: Tensor):
     def forward(self, x: Tensor):
         features = self.backbone(x)
         features = self.backbone(x)
         features = self.neck(features)
         features = self.neck(features)
Discard
@@ -1,5 +1,5 @@
 import math
 import math
-from typing import Union, Type, List, Tuple
+from typing import Union, Type, List, Tuple, Optional
 
 
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
@@ -11,6 +11,10 @@ from super_gradients.training.models.sg_module import SgModule
 from super_gradients.training.utils import torch_version_is_greater_or_equal
 from super_gradients.training.utils import torch_version_is_greater_or_equal
 from super_gradients.training.utils.detection_utils import non_max_suppression, matrix_non_max_suppression, NMS_Type, DetectionPostPredictionCallback, Anchors
 from super_gradients.training.utils.detection_utils import non_max_suppression, matrix_non_max_suppression, NMS_Type, DetectionPostPredictionCallback, Anchors
 from super_gradients.training.utils.utils import HpmStruct, check_img_size_divisibility, get_param
 from super_gradients.training.utils.utils import HpmStruct, check_img_size_divisibility, get_param
+from super_gradients.training.models.results import DetectionResults
+from super_gradients.training.pipelines.pipelines import DetectionPipeline
+from super_gradients.training.transforms.processing import Processing
+
 
 
 COCO_DETECTION_80_CLASSES_BBOX_ANCHORS = Anchors(
 COCO_DETECTION_80_CLASSES_BBOX_ANCHORS = Anchors(
     [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], strides=[8, 16, 32]
     [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], strides=[8, 16, 32]
@@ -80,6 +84,11 @@ class YoloPostPredictionCallback(DetectionPostPredictionCallback):
         self.with_confidence = with_confidence
         self.with_confidence = with_confidence
 
 
     def forward(self, x, device: str = None):
     def forward(self, x, device: str = None):
+        """Apply NMS to the raw output of the model and keep only top `max_predictions` results.
+
+        :param x: Raw output of the model, with x[0] expected to be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...)
+        :return: List of Tensors of shape (x1, y1, x2, y2, conf, cls)
+        """
 
 
         if self.nms_type == NMS_Type.ITERATIVE:
         if self.nms_type == NMS_Type.ITERATIVE:
             nms_result = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, with_confidence=self.with_confidence)
             nms_result = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, with_confidence=self.with_confidence)
@@ -90,7 +99,6 @@ class YoloPostPredictionCallback(DetectionPostPredictionCallback):
 
 
     def _filter_max_predictions(self, res: List) -> List:
     def _filter_max_predictions(self, res: List) -> List:
         res[:] = [im[: self.max_pred] if (im is not None and im.shape[0] > self.max_pred) else im for im in res]
         res[:] = [im[: self.max_pred] if (im is not None and im.shape[0] > self.max_pred) else im for im in res]
-
         return res
         return res
 
 
 
 
@@ -408,6 +416,36 @@ class YoloBase(SgModule):
             self._head = YoloHead(self.arch_params)
             self._head = YoloHead(self.arch_params)
             self._initialize_module()
             self._initialize_module()
 
 
+        self._class_names: Optional[List[str]] = None
+        self._image_processor: Optional[Processing] = None
+
+    @staticmethod
+    def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
+        return YoloPostPredictionCallback(conf=conf, iou=iou)
+
+    def set_dataset_processing_params(self, class_names: Optional[List[str]], image_processor: Optional[Processing]) -> None:
+        """Set the processing parameters for the dataset.
+
+        :param class_names:     (Optional) Names of the dataset the model was trained on.
+        :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training.
+        """
+        self._class_names = class_names or self._class_names
+        self._image_processor = image_processor or self._image_processor
+
+    def predict(self, images, iou: float = 0.65, conf: float = 0.01) -> DetectionResults:
+        if self._class_names is None or self._image_processor is None:
+            raise RuntimeError(
+                "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first."
+            )
+
+        pipeline = DetectionPipeline(
+            model=self,
+            image_processor=self._image_processor,
+            post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
+            class_names=self._class_names,
+        )
+        return pipeline(images)
+
     def forward(self, x):
     def forward(self, x):
         out = self._backbone(x)
         out = self._backbone(x)
         out = self._head(out)
         out = self._head(out)
@@ -429,9 +467,7 @@ class YoloBase(SgModule):
         self._initialize_biases()
         self._initialize_biases()
         self._initialize_weights()
         self._initialize_weights()
         if self.arch_params.add_nms:
         if self.arch_params.add_nms:
-            nms_conf = self.arch_params.nms_conf
-            nms_iou = self.arch_params.nms_iou
-            self._nms = YoloPostPredictionCallback(nms_conf, nms_iou)
+            self._nms = self.get_post_prediction_callback(conf=self.arch_params.nms_conf, iou=self.arch_params.nms_iou)
 
 
     def _check_strides(self):
     def _check_strides(self):
         m = self._head._modules_list[-1]  # DetectX()
         m = self._head._modules_list[-1]  # DetectX()
Discard
@@ -20,6 +20,7 @@ from super_gradients.training.utils.checkpoint_utils import (
 )
 )
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names
 from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names
+from super_gradients.training.transforms.processing import get_pretrained_processing_params
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -135,6 +136,9 @@ def instantiate_model(
                 net.replace_head(new_num_classes=num_classes_new_head)
                 net.replace_head(new_num_classes=num_classes_new_head)
                 arch_params.num_classes = num_classes_new_head
                 arch_params.num_classes = num_classes_new_head
 
 
+            class_names, image_processor = get_pretrained_processing_params(model_name, pretrained_weights)
+            net.set_dataset_processing_params(class_names, image_processor)
+
     _add_model_name_attribute(net, model_name)
     _add_model_name_attribute(net, model_name)
 
 
     return net
     return net
Discard
@@ -29,13 +29,27 @@ class DetectionPrediction(Prediction):
         :param labels:      Labels for each bounding box.
         :param labels:      Labels for each bounding box.
         :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format
         :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format
         """
         """
+        self._validate_input(bboxes, confidence, labels)
+
         factory = BBoxFormatFactory()
         factory = BBoxFormatFactory()
-        self.bboxes_xyxy = convert_bboxes(
+        bboxes_xyxy = convert_bboxes(
             bboxes=bboxes,
             bboxes=bboxes,
             image_shape=image_shape,
             image_shape=image_shape,
             source_format=factory.get(bbox_format),
             source_format=factory.get(bbox_format),
             target_format=factory.get("xyxy"),
             target_format=factory.get("xyxy"),
             inplace=False,
             inplace=False,
         )
         )
+
+        self.bboxes_xyxy = bboxes_xyxy
         self.confidence = confidence
         self.confidence = confidence
         self.labels = labels
         self.labels = labels
+
+    def _validate_input(self, bboxes: np.ndarray, confidence: np.ndarray, labels: np.ndarray) -> None:
+        """Check that bboxes, confidence scores and labels all describe the same number of detections.
+
+        :param bboxes:      Bounding boxes, one row per detection.
+        :param confidence:  Confidence scores, one entry per detection.
+        :param labels:      Labels, one entry per detection.
+        :raises ValueError: If the first dimensions of the three arrays are not all equal.
+        """
+        n_bboxes, n_confidences, n_labels = bboxes.shape[0], confidence.shape[0], labels.shape[0]
+        # `a != b != c` means `a != b and b != c`, which lets mismatches such as (2, 2, 3) through.
+        if not (n_bboxes == n_confidences == n_labels):
+            raise ValueError(
+                f"The number of bounding boxes ({n_bboxes}) does not match the number of confidence scores ({n_confidences}) and labels ({n_labels})."
+            )
+
+    def __len__(self):
+        return len(self.bboxes_xyxy)
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
  1. from abc import ABC, abstractmethod
  2. from typing import List, Optional, Tuple
  3. from dataclasses import dataclass
  4. from matplotlib import pyplot as plt
  5. import numpy as np
  6. from super_gradients.training.utils.detection_utils import DetectionVisualization
  7. from super_gradients.training.models.predictions import Prediction, DetectionPrediction
  @dataclass
  class Result(ABC):
      """Base container for the outcome of a computer vision task (detection, classification, etc.) on one image.

      :attr image:        Input image.
      :attr predictions:  Predictions of the model.
      :attr class_names:  List of the class names to predict.
      """

      image: np.ndarray
      predictions: Prediction
      class_names: List[str]

      @abstractmethod
      def draw(self) -> np.ndarray:
          """Draw the predictions on the image and return the drawn image."""
          ...

      @abstractmethod
      def show(self) -> None:
          """Display the image with the predictions drawn on it."""
          ...
  @dataclass
  class Results(ABC):
      """Base container for the results of a computer vision task (detection, classification, etc.) over a run.

      :attr results: List of results of the run.
      """

      results: List[Result]

      @abstractmethod
      def draw(self) -> List[np.ndarray]:
          """Draw the predictions on the images and return the drawn images."""
          ...

      @abstractmethod
      def show(self) -> None:
          """Display the images with the predictions drawn on them."""
          ...
  @dataclass
  class DetectionResult(Result):
      """Detections predicted on a single image.

      :attr image:        Input image.
      :attr predictions:  Detections predicted by the model on the image.
      :attr class_names:  List of the class names to predict.
      """

      image: np.ndarray
      predictions: DetectionPrediction
      class_names: List[str]

      def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> np.ndarray:
          """Render every predicted bounding box, with its title, on a copy of the image.

          :param box_thickness:   Thickness of bounding boxes.
          :param show_confidence: Whether to show confidence scores on the image.
          :param color_mapping:   List of tuples representing the colors for each class.
                                  When None (or empty), a default mapping is generated from the number of class names.
          :return:                Image with predicted bboxes. The original image is not modified.
          """
          canvas = self.image.copy()  # draw on a copy so that `self.image` stays untouched
          if not color_mapping:
              color_mapping = DetectionVisualization._generate_color_mapping(len(self.class_names))

          detections = self.predictions
          for index in range(len(detections)):
              box = detections.bboxes_xyxy[index]
              canvas = DetectionVisualization._draw_box_title(
                  color_mapping=color_mapping,
                  class_names=self.class_names,
                  box_thickness=box_thickness,
                  image_np=canvas,
                  x1=int(box[0]),
                  y1=int(box[1]),
                  x2=int(box[2]),
                  y2=int(box[3]),
                  class_id=int(detections.labels[index]),
                  pred_conf=detections.confidence[index] if show_confidence else None,
              )
          return canvas

      def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> None:
          """Display the image with the predicted bboxes drawn on it, using matplotlib.

          :param box_thickness:   Thickness of bounding boxes.
          :param show_confidence: Whether to show confidence scores on the image.
          :param color_mapping:   List of tuples representing the colors for each class.
                                  When None (or empty), a default mapping is generated from the number of class names.
          """
          drawn = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
          plt.imshow(drawn, interpolation="nearest")
          plt.axis("off")
          plt.show()
  @dataclass
  class DetectionResults(Results):
      """Detections predicted on a list of images.

      :attr results: List of the predictions results, one `DetectionResult` per (image, prediction) pair.
      """

      def __init__(self, images: List[np.ndarray], predictions: List[DetectionPrediction], class_names: List[str]):
          # Images and predictions are paired positionally; every result shares the same class names.
          self.results: List[DetectionResult] = [
              DetectionResult(image=image, predictions=prediction, class_names=class_names) for image, prediction in zip(images, predictions)
          ]

      def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> List[np.ndarray]:
          """Draw the predicted bboxes on each image.

          :param box_thickness:   Thickness of bounding boxes.
          :param show_confidence: Whether to show confidence scores on the image.
          :param color_mapping:   List of tuples representing the colors for each class.
                                  When None (or empty), a default mapping is generated from the number of class names.
          :return:                List of images with predicted bboxes, one per result. The original images are not modified.
          """
          drawn_images = []
          for result in self.results:
              drawn_images.append(result.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping))
          return drawn_images

      def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> None:
          """Display each image in turn with its predicted bboxes.

          :param box_thickness:   Thickness of bounding boxes.
          :param show_confidence: Whether to show confidence scores on the image.
          :param color_mapping:   List of tuples representing the colors for each class.
                                  When None (or empty), a default mapping is generated from the number of class names.
          """
          for result in self.results:
              result.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
Discard
@@ -3,6 +3,7 @@ from typing import Union
 from torch import nn
 from torch import nn
 
 
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.utils.utils import HpmStruct
+from super_gradients.training.models.results import Result
 
 
 
 
 class SgModule(nn.Module):
 class SgModule(nn.Module):
@@ -62,3 +63,10 @@ class SgModule(nn.Module):
         """
         """
 
 
         raise NotImplementedError
         raise NotImplementedError
+
+    def predict(self, images, *args, **kwargs) -> Result:
+        raise NotImplementedError(f"`predict` is not implemented for {self.__class__.__name__}.")
+
+    def set_dataset_processing_params(self, *args, **kwargs) -> None:
+        """Set the processing parameters for the dataset."""
+        pass
Discard
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    1. from abc import ABC, abstractmethod
    2. from typing import List, Optional, Tuple, Union
    3. from contextlib import contextmanager
    4. import numpy as np
    5. import torch
    6. from super_gradients.training.utils.load_image import load_images, ImageType
    7. from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
    8. from super_gradients.training.models.sg_module import SgModule
    9. from super_gradients.training.models.results import Results, DetectionResults
    10. from super_gradients.training.models.predictions import Prediction, DetectionPrediction
    11. from super_gradients.training.transforms.processing import Processing, ComposeProcessing
    @contextmanager
    def eval_mode(model: SgModule) -> None:
        """Temporarily set a model in evaluation mode and deactivate gradient computation.

        On exit the model is switched back to the mode (train/eval) it was in on entry, including when the
        body of the `with` block raises.

        :param model: The model to set in evaluation mode.
        """
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                yield
        finally:
            # Without `finally`, an exception raised inside the `with` body would leave the model stuck in eval mode.
            model.train(mode=was_training)
    class Pipeline(ABC):
        """Abstract base class of a processing pipeline for a specific task.

        A pipeline chains image loading, preprocessing, model prediction and postprocessing.

        :param model:           The model used for making predictions.
        :param image_processor: A single image processor or a list of image processors for preprocessing and postprocessing the images.
                                A list is wrapped into a single `ComposeProcessing`.
        :param device:          The device on which the model will be run. Defaults to "cpu". Use "cuda" for GPU support.
        """

        def __init__(self, model: SgModule, image_processor: Union[Processing, List[Processing]], device: Optional[str] = "cpu"):
            super().__init__()
            self.device = device
            self.model = model.to(device)
            self.image_processor = ComposeProcessing(image_processor) if isinstance(image_processor, list) else image_processor

        @abstractmethod
        def __call__(self, images: Union[ImageType, List[ImageType]]) -> Union[Results, Tuple[List[np.ndarray], List[Prediction]]]:
            """Apply the pipeline on images and return the result.

            Abstract, but the base implementation is usable through `super().__call__` and returns the raw output of `_run`.

            :param images:  Single image or a list of images of supported types.
            :return:        Results object containing the results of the prediction and the image.
            """
            return self._run(images=images)

        def _run(self, images: Union[ImageType, List[ImageType]]) -> Tuple[List[np.ndarray], List[Prediction]]:
            """Run the pipeline and return (images, predictions). The pipeline is made of 4 steps:
                1. Load images - Loading the images into a list of numpy arrays.
                2. Preprocess - Encode the image in the shape/format expected by the model
                3. Predict - Run the model on the preprocessed image
                4. Postprocess - Decode the output of the model so that the predictions are in the shape/format of original image.

            :param images:  Single image or a list of images of supported types.
            :return:
                - List of numpy arrays representing images.
                - List of model predictions.
            """
            # The model might have been moved to another device after init, so move it back before running.
            self.model = self.model.to(self.device)
            loaded_images = load_images(images)

            # Preprocess: each image is copied first so the loaded originals can be returned unchanged.
            model_inputs, metadatas = [], []
            for loaded_image in loaded_images:
                model_input, metadata = self.image_processor.preprocess_image(image=loaded_image.copy())
                model_inputs.append(model_input)
                metadatas.append(metadata)

            # Predict: all preprocessed images are stacked and run as one batch.
            with eval_mode(self.model):
                batch = torch.Tensor(np.array(model_inputs)).to(self.device)
                raw_output = self.model(batch)
                decoded_predictions = self._decode_model_output(raw_output, model_input=batch)

            # Postprocess: undo the preprocessing on each prediction using the metadata recorded for its image.
            final_predictions = [
                self.image_processor.postprocess_predictions(predictions=decoded, metadata=metadata)
                for decoded, metadata in zip(decoded_predictions, metadatas)
            ]
            return loaded_images, final_predictions

        @abstractmethod
        def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[Prediction]:
            """Decode the model outputs, move each prediction to numpy and store it in a Prediction object.

            :param model_output:    Direct output of the model, without any post-processing.
            :param model_input:     Model input (i.e. images after preprocessing).
            :return:                Model predictions, without any post-processing.
            """
    class DetectionPipeline(Pipeline):
        """Pipeline specifically designed for object detection tasks.

        The pipeline includes loading images, preprocessing, prediction, and postprocessing.

        :param model:                       The object detection model (instance of SgModule) used for making predictions.
        :param class_names:                 List of class names corresponding to the model's output classes.
        :param post_prediction_callback:    Callback function to process raw predictions from the model.
        :param image_processor:             Single image processor or a list of image processors for preprocessing and postprocessing the images.
        :param device:                      The device on which the model will be run. Defaults to "cpu". Use "cuda" for GPU support.
        """

        def __init__(
            self,
            model: SgModule,
            class_names: List[str],
            post_prediction_callback: DetectionPostPredictionCallback,
            device: Optional[str] = "cpu",
            image_processor: Optional[Processing] = None,
        ):
            super().__init__(model=model, device=device, image_processor=image_processor)
            self.post_prediction_callback = post_prediction_callback
            self.class_names = class_names

        def __call__(self, images: Union[List[ImageType], ImageType]) -> DetectionResults:
            """Apply the pipeline on images and return the detection result.

            :param images:  Single image or a list of images of supported types.
            :return:        Results object containing the results of the prediction and the image.
            """
            images, predictions = super().__call__(images=images)
            return DetectionResults(images=images, predictions=predictions, class_names=self.class_names)

        def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[DetectionPrediction]:
            """Decode the model output, by applying post prediction callback. This includes NMS.

            :param model_output:    Direct output of the model, without any post-processing.
            :param model_input:     Model input (i.e. images after preprocessing).
            :return:                Predicted Bboxes, one `DetectionPrediction` per input image.
            """
            post_nms_predictions = self.post_prediction_callback(model_output, device=self.device)

            predictions = []
            for prediction, image in zip(post_nms_predictions, model_input):
                # An entry may be None; substitute an empty (0, 6) tensor so the slicing below still works.
                # The original evaluated this fallback as a bare expression and discarded it, so None crashed on `.detach()`.
                if prediction is None:
                    prediction = torch.zeros((0, 6), dtype=torch.float32)
                prediction = prediction.detach().cpu().numpy()

                # Columns 0-3 are read as xyxy box coordinates, column 4 as confidence and column 5 as label.
                predictions.append(
                    DetectionPrediction(
                        bboxes=prediction[:, :4],
                        confidence=prediction[:, 4],
                        labels=prediction[:, 5],
                        bbox_format="xyxy",
                        image_shape=image.shape,
                    )
                )
            return predictions
    Discard
    @@ -1,4 +1,4 @@
    -from typing import Tuple, List, Union
    +from typing import Tuple, List, Union, Optional
     from abc import ABC, abstractmethod
     from abc import ABC, abstractmethod
     from dataclasses import dataclass
     from dataclasses import dataclass
     
     
    @@ -202,3 +202,43 @@ class DetectionLongestMaxSizeRescale(_LongestMaxSizeRescale):
         def postprocess_predictions(self, predictions: DetectionPrediction, metadata: RescaleMetadata) -> DetectionPrediction:
         def postprocess_predictions(self, predictions: DetectionPrediction, metadata: RescaleMetadata) -> DetectionPrediction:
             predictions.bboxes_xyxy = _rescale_bboxes(targets=predictions.bboxes_xyxy, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w))
             predictions.bboxes_xyxy = _rescale_bboxes(targets=predictions.bboxes_xyxy, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w))
             return predictions
             return predictions
    +
    +
    +def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -> Tuple[Optional[List[str]], Optional[Processing]]:
    +    """Get the processing parameters for a pretrained model."""
    +    if "yolox" in model_name and pretrained_weights == "coco":
    +        return default_yolox_coco_processing_params()
    +    elif "ppyoloe" in model_name and pretrained_weights == "coco":
    +        return default_ppyoloe_coco_processing_params()
    +    else:
    +        return None, None
    +
    +
    +def default_yolox_coco_processing_params() -> Tuple[List[str], Processing]:
    +    """Processing parameters commonly used for training YoloX on COCO dataset."""
    +    from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
    +
    +    image_processor = ComposeProcessing(
    +        [
    +            DetectionLongestMaxSizeRescale((640, 640)),
    +            DetectionBottomRightPadding((640, 640), 114),
    +            ImagePermute((2, 0, 1)),
    +        ]
    +    )
    +    class_names = COCO_DETECTION_CLASSES_LIST
    +    return class_names, image_processor
    +
    +
    +def default_ppyoloe_coco_processing_params() -> Tuple[List[str], Processing]:
    +    """Processing parameters commonly used for training PPYoloE on COCO dataset."""
    +    from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
    +
    +    image_processor = ComposeProcessing(
    +        [
    +            DetectionRescale(output_shape=(640, 640)),
    +            NormalizeImage(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
    +            ImagePermute(permutation=(2, 0, 1)),
    +        ]
    +    )
    +    class_names = COCO_DETECTION_CLASSES_LIST
    +    return class_names, image_processor
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    1. from typing import Union, List
    2. import PIL
    3. import numpy as np
    4. import torch
    5. import requests
    6. from urllib.parse import urlparse
    7. ImageType = Union[str, np.ndarray, torch.Tensor, PIL.Image.Image]
    def load_images(images: Union[List[ImageType], ImageType]) -> List[np.ndarray]:
        """Load one image or a list of images and return them as a list of numpy arrays.

        Anything that is not a `list` is treated as a single image. Supported image types include:
            - numpy.ndarray:    A numpy array representing the image
            - torch.Tensor:     A PyTorch tensor representing the image
            - PIL.Image.Image:  A PIL Image object
            - str:              A string representing either a local file path or a URL to an image

        :param images:  Single image or a list of images of supported types.
        :return:        List of images as numpy arrays. If loaded from string, the image will be returned as RGB.
        """
        batch = images if isinstance(images, list) else [images]
        return [load_image(image=item) for item in batch]
    def load_image(image: ImageType) -> np.ndarray:
        """Load a single image and return it as a numpy array.

        Supported image types include:
            - numpy.ndarray:    A numpy array representing the image (returned as is, without copy)
            - torch.Tensor:     A PyTorch tensor representing the image
            - PIL.Image.Image:  A PIL Image object
            - str:              A string representing either a local file path or a URL to an image

        :param image:       Single image of supported types.
        :return:            Image as numpy array. If loaded from string or PIL, the image will be returned as RGB.
        :raises ValueError: If the type of `image` is not supported.
        """
        if isinstance(image, np.ndarray):
            return image
        if isinstance(image, torch.Tensor):
            # A bare `Tensor.numpy()` raises for tensors that require grad or live on a non-CPU device.
            return image.detach().cpu().numpy()
        if isinstance(image, PIL.Image.Image):
            return load_np_image_from_pil(image)
        if isinstance(image, str):
            return load_np_image_from_pil(load_pil_image_from_str(image_str=image))
        raise ValueError(f"Unsupported image type: {type(image)}")
    def load_np_image_from_pil(image: PIL.Image.Image) -> np.ndarray:
        """Return a PIL image as a numpy array, after converting it to RGB mode."""
        rgb_image = image.convert("RGB")
        return np.asarray(rgb_image)
    def load_pil_image_from_str(image_str: str) -> PIL.Image.Image:
        """Load an image based on a string (local file path or URL).

        :param image_str:   Local file path, or URL of the image to download.
        :return:            The opened PIL image.
        :raises requests.HTTPError: If `image_str` is a URL and the server answers with an error status.
        """
        import io  # local import: only needed to wrap the downloaded bytes

        if not is_url(image_str):
            return PIL.Image.open(image_str)

        response = requests.get(image_str)
        response.raise_for_status()
        # Read the decoded body instead of `response.raw`: the raw stream skips content decoding (gzip/deflate)
        # and keeps the connection open until PIL happens to finish reading from it.
        return PIL.Image.open(io.BytesIO(response.content))
    def is_url(url: str) -> bool:
        """Check if the given string is a URL, i.e. parses with a non-empty scheme, network location and path."""
        try:
            parts = urlparse(url)
        except Exception:
            # Anything `urlparse` cannot handle is simply reported as "not a URL".
            return False
        return bool(parts.scheme and parts.netloc and parts.path)
    Discard