Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#817 Added tutorial on DetectionOutputAdapter

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-744_DetectionOutputAdapter_Docs
1 changed files with 45 additions and 2 deletions
  1. 45
    2
      tests/unit_tests/detection_output_adapter_test.py
@@ -7,14 +7,20 @@ import onnx
 import onnxruntime as ort
 import onnxruntime as ort
 import torch.jit
 import torch.jit
 
 
-from super_gradients.training.datasets.data_formats.bbox_formats import NormalizedXYWHCoordinateFormat, CXCYWHCoordinateFormat, YXYXCoordinateFormat
-from super_gradients.training.datasets.data_formats.output_adapters.detection_adapter import DetectionOutputAdapter
 from super_gradients.training.datasets.data_formats import (
 from super_gradients.training.datasets.data_formats import (
     ConcatenatedTensorFormat,
     ConcatenatedTensorFormat,
     BoundingBoxesTensorSliceItem,
     BoundingBoxesTensorSliceItem,
     TensorSliceItem,
     TensorSliceItem,
+    XYXYCoordinateFormat,
+    NormalizedXYWHCoordinateFormat,
+    CXCYWHCoordinateFormat,
+    YXYXCoordinateFormat,
+    NormalizedCXCYWHCoordinateFormat,
+    DetectionOutputAdapter,
 )
 )
 
 
+from super_gradients.training.datasets.data_formats.bbox_formats.normalized_cxcywh import xyxy_to_normalized_cxcywh
+
 NORMALIZED_XYWH_SCORES_LABELS = ConcatenatedTensorFormat(
 NORMALIZED_XYWH_SCORES_LABELS = ConcatenatedTensorFormat(
     layout=(
     layout=(
         BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
         BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
@@ -135,6 +141,43 @@ class TestDetectionOutputAdapter(unittest.TestCase):
 
 
             np.testing.assert_allclose(actual_output, expected_output)
             np.testing.assert_allclose(actual_output, expected_output)
 
 
+    def test_output_adapter_manual_case(self):
+
+        image_shape = 640, 640
+
+        expected_bboxes_xyxy = np.array(
+            [
+                [256, 320, 340, 400],
+                [32, 64, 100, 150],
+                [0, 0, 100, 100],
+            ]
+        )
+
+        input_bboxes_cxcywh = xyxy_to_normalized_cxcywh(expected_bboxes_xyxy, image_shape)
+        input_labels = np.arange(len(expected_bboxes_xyxy))
+        input = torch.from_numpy(np.concatenate([input_bboxes_cxcywh, input_labels[:, None]], axis=-1))
+        print(input.numpy())
+
+        input_format = ConcatenatedTensorFormat(
+            layout=(
+                BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
+                TensorSliceItem(name="class", length=1),
+            )
+        )
+
+        output_format = ConcatenatedTensorFormat(
+            layout=(
+                TensorSliceItem(name="class", length=1),
+                BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
+            )
+        )
+
+        output_adapter = DetectionOutputAdapter(input_format, output_format, image_shape)
+        output = output_adapter(input)
+        output_bboxes = output[:, 1:].numpy()
+        print(output.numpy())
+        np.testing.assert_allclose(output_bboxes, expected_bboxes_xyxy)
+
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     unittest.main()
     unittest.main()
Discard