#817 Added tutorial on DetectionOutputAdapter

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-744_DetectionOutputAdapter_Docs

DetectionOutputAdapter

The DetectionOutputAdapter is a class that converts the output of a detection model into the format the user needs. For instance, it can convert the format of bounding boxes from CXCYWH to XYXY, or change the layout of the elements in the output tensor from [X1, Y1, X2, Y2, Confidence, Class] to [Class, Confidence, X1, Y1, X2, Y2].

Features

  • Easy rearrangement of the elements in the output tensor
  • Easy conversion of the bounding box format
  • Support for JIT tracing & scripting
  • Support for ONNX export

Usage

We start by introducing the concept of a format. A format represents a specific layout of the elements in the output tensor. Currently, only one format type is supported: ConcatenatedTensorFormat, which represents a layout where all predictions are concatenated into a single tensor. Additional formats (like DictionaryOfTensorsFormat) may be added in the future.

ConcatenatedTensorFormat requires the input to be a tensor with one of the following shapes:

  • Tensor of shape [N, Elements] - N is the number of predictions, Elements is the concatenated vector of attributes per box.
  • Tensor of shape [B, N, Elements] - B is the batch dimension, N and Elements as above.
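
For example, both of the following are valid inputs for a format with 10 elements per prediction (a minimal sketch, shapes only):

import torch

predictions = torch.rand(100, 10)             # [N, Elements]: 100 predictions, 10 elements each
batched_predictions = torch.rand(8, 100, 10)  # [B, N, Elements]: a batch of 8 images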

To instantiate the DetectionOutputAdapter, we have to describe the input and output formats of our predictions.

Let's imagine the model emits predictions in the following format:

import numpy as np

# [N, 10] (cx, cy, w, h, class, confidence, attributes..)
example_input = np.array([
    #      cx          cy        w          h     class, confidence,   attribute a, attribute b, attribute c, attribute d
    [0.465625,  0.5625,    0.13125,   0.125,          0,      0.968,         0.350,       0.643,       0.640,       0.453],
    [0.103125,  0.1671875, 0.10625,   0.134375,       1,      0.897,         0.765,       0.654,       0.324,       0.816],
    [0.078125,  0.078125,  0.15625,   0.15625,        2.,     0.423,         0.792,       0.203,       0.653,       0.777],
    # ... more predictions
])

The corresponding format definition would look like this:

from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem, NormalizedCXCYWHCoordinateFormat

input_format = ConcatenatedTensorFormat(
    layout=(
        BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
        TensorSliceItem(name="class", length=1),
        TensorSliceItem(name="confidence", length=1),
        TensorSliceItem(name="attributes", length=4),
    )
)
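
As a quick sanity check, the slice lengths add up to the 10 columns of example_input; num_channels is the same attribute the adapter itself uses to validate the input width:

# 4 bbox coordinates + 1 class + 1 confidence + 4 attributes = 10
assert input_format.num_channels == 10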

For the sake of demonstration, let's assume we want to convert the output to the following format:

# [N, 10] (class, attributes, x1, y1, x2, y2)
[
    # class, attribute a, attribute b, attribute c, attribute d,     x1,   y1,   x2,    y2
    [     0,       0.350,       0.643,       0.640,       0.453,    256,  320,  340,   400],
    [     1,       0.765,       0.654,       0.324,       0.816,     32,   64,  100,   150],
    [     2,       0.792,       0.203,       0.653,       0.777,      0,    0,  100,   100],
    ...
]
  • The class and attributes are the same as in the input format but come first
  • The format of bounding boxes is changed from NormalizedCXCYWHCoordinateFormat to XYXYCoordinateFormat
  • The confidence is removed from the output

The corresponding format definition would look like this (note that the confidence is dropped simply by omitting it from the output layout):

from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem, XYXYCoordinateFormat

output_format = ConcatenatedTensorFormat(
    layout=(
        TensorSliceItem(name="class", length=1),
        TensorSliceItem(name="attributes", length=4),
        BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
    )
)

Now we can construct the DetectionOutputAdapter and attach it to the model:

import torch
from torch import nn

from super_gradients.training.datasets.data_formats import DetectionOutputAdapter

output_adapter = DetectionOutputAdapter(input_format, output_format, image_shape=(640, 640))

model = nn.Sequential(
    create_model(),  # placeholder for your detection model
    create_nms(),    # placeholder for your NMS post-processing step
    output_adapter,
)

To test how the output adapter transforms a dummy input, you can run it on its own:

output = output_adapter(torch.from_numpy(example_input)).numpy()
print(output)

# Prints (values formatted for readability):
[
    # class,   attribute a, attribute b, attribute c, attribute d,     x1,   y1,   x2,    y2
    [     0,         0.350,       0.643,       0.640,       0.453,    256,  320,  340,   400], 
    [     1,         0.765,       0.654,       0.324,       0.816,     32,   64,  100,   150], 
    [     2,         0.792,       0.203,       0.653,       0.777,      0,    0,  100,   100]
]
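
Since the adapter supports JIT scripting and ONNX export (see the feature list above), it can also be frozen for deployment. A minimal sketch, assuming the adapter built above; the exact export arguments (opset version, dynamic axes, file name) depend on your setup:

scripted_adapter = torch.jit.script(output_adapter)

torch.onnx.export(
    scripted_adapter,
    torch.rand(3, 10),  # dummy [N, Elements] predictions
    "output_adapter.onnx",
)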

Unsupported features

Currently, DetectionOutputAdapter does not support the following features:

  • argmax operation over a slice of confidences for [C] classes (useful for computing argmax(class confidences))
  • Multiplication of two slices (useful for computing confidence * class); a manual workaround for both is sketched below
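
The following hypothetical sketch applies both operations after the adapter; the fuse_and_argmax name, NUM_CLASSES, and the assumed [confidence, class_0 .. class_C-1] channel layout are illustrative, not part of the library:

NUM_CLASSES = 80  # hypothetical number of classes

def fuse_and_argmax(predictions: torch.Tensor) -> torch.Tensor:
    # Assumes the last (1 + NUM_CLASSES) channels are [confidence, class_0 .. class_C-1]
    confidence = predictions[..., -(NUM_CLASSES + 1):-NUM_CLASSES]
    class_scores = predictions[..., -NUM_CLASSES:]
    fused = confidence * class_scores                         # multiplication of two slices
    best_score, best_class = fused.max(dim=-1, keepdim=True)  # argmax over class confidences
    boxes = predictions[..., :-(NUM_CLASSES + 1)]
    return torch.cat([boxes, best_score, best_class.to(boxes.dtype)], dim=-1)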
@@ -1,5 +1,27 @@
 from .format_converter import ConcatenatedTensorFormatConverter
 from .output_adapters import DetectionOutputAdapter
 from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
+from .bbox_formats import (
+    CXCYWHCoordinateFormat,
+    NormalizedCXCYWHCoordinateFormat,
+    NormalizedXYWHCoordinateFormat,
+    NormalizedXYXYCoordinateFormat,
+    XYWHCoordinateFormat,
+    XYXYCoordinateFormat,
+    YXYXCoordinateFormat,
+)
 
-__all__ = ["ConcatenatedTensorFormatConverter", "DetectionOutputAdapter", "TensorSliceItem", "ConcatenatedTensorFormat", "BoundingBoxesTensorSliceItem"]
+__all__ = [
+    "BoundingBoxesTensorSliceItem",
+    "CXCYWHCoordinateFormat",
+    "ConcatenatedTensorFormat",
+    "ConcatenatedTensorFormatConverter",
+    "DetectionOutputAdapter",
+    "NormalizedCXCYWHCoordinateFormat",
+    "NormalizedXYWHCoordinateFormat",
+    "NormalizedXYXYCoordinateFormat",
+    "TensorSliceItem",
+    "XYWHCoordinateFormat",
+    "XYXYCoordinateFormat",
+    "YXYXCoordinateFormat",
+]
@@ -114,7 +114,7 @@ class DetectionOutputAdapter(nn.Module):
     >>> )
     >>>
     >>> # Now we can construct output adapter and attach it to the model
-    >>> output_adapter = DetectionOutputAdapter(yolox,
+    >>> output_adapter = DetectionOutputAdapter(
     >>>     input_format=yolox.head.format,
     >>>     output_format=output_format,
     >>>     image_shape=(640, 640)
@@ -133,14 +133,16 @@ class DetectionOutputAdapter(nn.Module):
                             If you're not using normalized coordinates you can set this to None
         """
         super().__init__()
-        self.rearrange_outputs, rearranged_format = self.get_rearrange_outputs_module(input_format, output_format)
 
         self.format_conversion: nn.Module = self.get_format_conversion_module(
-            location=rearranged_format.locations[rearranged_format.bboxes_format.name],
-            input_bbox_format=rearranged_format.bboxes_format.format,
+            location=input_format.locations[input_format.bboxes_format.name],
+            input_bbox_format=input_format.bboxes_format.format,
             output_bbox_format=output_format.bboxes_format.format,
             image_shape=image_shape,
         )
+
+        self.rearrange_outputs, rearranged_format = self.get_rearrange_outputs_module(input_format, output_format)
+
         self.input_format = input_format
         self.output_format = output_format
         self.input_length = input_format.num_channels
@@ -157,8 +159,8 @@ class DetectionOutputAdapter(nn.Module):
                 f"equal to {self.input_length} as defined by input format."
             )
 
+        predictions = self.format_conversion(predictions.clone())
         predictions = self.rearrange_outputs(predictions)
-        predictions = self.format_conversion(predictions)
         return predictions
 
     @classmethod
@@ -7,14 +7,20 @@ import onnx
 import onnxruntime as ort
 import torch.jit
 
-from super_gradients.training.datasets.data_formats.bbox_formats import NormalizedXYWHCoordinateFormat, CXCYWHCoordinateFormat, YXYXCoordinateFormat
-from super_gradients.training.datasets.data_formats.output_adapters.detection_adapter import DetectionOutputAdapter
 from super_gradients.training.datasets.data_formats import (
     ConcatenatedTensorFormat,
     BoundingBoxesTensorSliceItem,
     TensorSliceItem,
+    XYXYCoordinateFormat,
+    NormalizedXYWHCoordinateFormat,
+    CXCYWHCoordinateFormat,
+    YXYXCoordinateFormat,
+    NormalizedCXCYWHCoordinateFormat,
+    DetectionOutputAdapter,
 )
 
+from super_gradients.training.datasets.data_formats.bbox_formats.normalized_cxcywh import xyxy_to_normalized_cxcywh
+
 NORMALIZED_XYWH_SCORES_LABELS = ConcatenatedTensorFormat(
     layout=(
         BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
@@ -135,6 +141,43 @@ class TestDetectionOutputAdapter(unittest.TestCase):
 
             np.testing.assert_allclose(actual_output, expected_output)
 
+    def test_output_adapter_manual_case(self):
+
+        image_shape = 640, 640
+
+        expected_bboxes_xyxy = np.array(
+            [
+                [256, 320, 340, 400],
+                [32, 64, 100, 150],
+                [0, 0, 100, 100],
+            ]
+        )
+
+        input_bboxes_cxcywh = xyxy_to_normalized_cxcywh(expected_bboxes_xyxy, image_shape)
+        input_labels = np.arange(len(expected_bboxes_xyxy))
+        input = torch.from_numpy(np.concatenate([input_bboxes_cxcywh, input_labels[:, None]], axis=-1))
+        print(input.numpy())
+
+        input_format = ConcatenatedTensorFormat(
+            layout=(
+                BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
+                TensorSliceItem(name="class", length=1),
+            )
+        )
+
+        output_format = ConcatenatedTensorFormat(
+            layout=(
+                TensorSliceItem(name="class", length=1),
+                BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
+            )
+        )
+
+        output_adapter = DetectionOutputAdapter(input_format, output_format, image_shape)
+        output = output_adapter(input)
+        output_bboxes = output[:, 1:].numpy()
+        print(output.numpy())
+        np.testing.assert_allclose(output_bboxes, expected_bboxes_xyxy)
+
 
 if __name__ == "__main__":
     unittest.main()