#561 Feature/sg 193 extend output formator

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-193-extend_detection_target_transform
27 changed files with 506 additions and 230 deletions
  1. src/super_gradients/common/factories/bbox_format_factory.py (+1, -1)
  2. src/super_gradients/common/factories/data_formats_factory.py (+7, -0)
  3. src/super_gradients/common/object_names.py (+15, -1)
  4. src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml (+4, -6)
  5. src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml (+5, -6)
  6. src/super_gradients/training/datasets/data_formats/__init__.py (+5, -0)
  7. src/super_gradients/training/datasets/data_formats/bbox_formats/__init__.py (+0, -0)
  8. src/super_gradients/training/datasets/data_formats/bbox_formats/bbox_format.py (+0, -0)
  9. src/super_gradients/training/datasets/data_formats/bbox_formats/cxcywh.py (+1, -1)
  10. src/super_gradients/training/datasets/data_formats/bbox_formats/normalized_cxcywh.py (+3, -3)
  11. src/super_gradients/training/datasets/data_formats/bbox_formats/normalized_xywh.py (+3, -3)
  12. src/super_gradients/training/datasets/data_formats/bbox_formats/normalized_xyxy.py (+1, -1)
  13. src/super_gradients/training/datasets/data_formats/bbox_formats/xywh.py (+1, -1)
  14. src/super_gradients/training/datasets/data_formats/bbox_formats/xyxy.py (+1, -1)
  15. src/super_gradients/training/datasets/data_formats/bbox_formats/yxyx.py (+1, -1)
  16. src/super_gradients/training/datasets/data_formats/default_formats.py (+103, -0)
  17. src/super_gradients/training/datasets/data_formats/format_converter.py (+63, -0)
  18. src/super_gradients/training/datasets/data_formats/formats.py (+173, -0)
  19. src/super_gradients/training/datasets/data_formats/output_adapters/__init__.py (+3, -0)
  20. src/super_gradients/training/datasets/data_formats/output_adapters/detection_adapter.py (+4, -4)
  21. src/super_gradients/training/transforms/all_transforms.py (+0, -2)
  22. src/super_gradients/training/transforms/transforms.py (+42, -65)
  23. src/super_gradients/training/utils/output_adapters/__init__.py (+0, -4)
  24. src/super_gradients/training/utils/output_adapters/formats.py (+0, -78)
  25. tests/unit_tests/bbox_formats_test.py (+8, -7)
  26. tests/unit_tests/detection_output_adapter_test.py (+8, -3)
  27. tests/unit_tests/detection_targets_format_transform_test.py (+54, -42)
src/super_gradients/common/factories/bbox_format_factory.py

@@ -1,5 +1,5 @@
 from super_gradients.common.factories.base_factory import BaseFactory
-from super_gradients.training.utils.bbox_formats import BBOX_FORMATS
+from super_gradients.training.datasets.data_formats.bbox_formats import BBOX_FORMATS


 class BBoxFormatFactory(BaseFactory):
src/super_gradients/common/factories/data_formats_factory.py (new file)

from super_gradients.common.factories.type_factory import TypeFactory
from super_gradients.training.datasets.data_formats.default_formats import DEFAULT_CONCATENATED_TENSOR_FORMATS


class ConcatenatedTensorFormatFactory(TypeFactory):
    def __init__(self):
        super().__init__(DEFAULT_CONCATENATED_TENSOR_FORMATS)
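For context: this new factory wires the format registry into the same factory machinery used elsewhere in the library, which is what lets the recipe files and transforms changed later in this PR accept plain strings such as "LABEL_CXCYWH" in place of format objects. A minimal sketch of the lookup, assuming TypeFactory exposes the usual BaseFactory-style get():

from super_gradients.common.factories.data_formats_factory import ConcatenatedTensorFormatFactory

factory = ConcatenatedTensorFormatFactory()
# Assumption: get() resolves a registered string key to the corresponding
# ConcatenatedTensorFormat object, as with other SuperGradients factories.
label_cxcywh = factory.get("LABEL_CXCYWH")
print(label_cxcywh.num_channels)  # 5: one label column plus four bbox columns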
src/super_gradients/common/object_names.py

@@ -50,7 +50,6 @@ class Transforms:
     DetectionHSV = "DetectionHSV"
     DetectionHorizontalFlip = "DetectionHorizontalFlip"
     DetectionPaddedRescale = "DetectionPaddedRescale"
-    DetectionTargetsFormat = "DetectionTargetsFormat"
     DetectionTargetsFormatTransform = "DetectionTargetsFormatTransform"
     RandomResizedCropAndInterpolation = "RandomResizedCropAndInterpolation"
     RandAugmentTransform = "RandAugmentTransform"
@@ -253,3 +252,18 @@ class Models:
     PP_LITE_B_SEG75 = "pp_lite_b_seg75"
     UNET_CUSTOM = "unet_custom"
     UNET_CUSTOM_CLS = "unet_custom_cls"
+
+
+class ConcatenatedTensorFormats:
+    XYXY_LABEL = "XYXY_LABEL"
+    XYWH_LABEL = "XYWH_LABEL"
+    CXCYWH_LABEL = "CXCYWH_LABEL"
+    LABEL_XYXY = "LABEL_XYXY"
+    LABEL_XYWH = "LABEL_XYWH"
+    LABEL_CXCYWH = "LABEL_CXCYWH"
+    NORMALIZED_XYXY_LABEL = "NORMALIZED_XYXY_LABEL"
+    NORMALIZED_XYWH_LABEL = "NORMALIZED_XYWH_LABEL"
+    NORMALIZED_CXCYWH_LABEL = "NORMALIZED_CXCYWH_LABEL"
+    LABEL_NORMALIZED_XYXY = "LABEL_NORMALIZED_XYXY"
+    LABEL_NORMALIZED_XYWH = "LABEL_NORMALIZED_XYWH"
+    LABEL_NORMALIZED_CXCYWH = "LABEL_NORMALIZED_CXCYWH"
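These constants mirror the keys of DEFAULT_CONCATENATED_TENSOR_FORMATS (defined in default_formats.py below), so callers can name a format without importing the format objects themselves. A small sketch:

from super_gradients.common.object_names import ConcatenatedTensorFormats
from super_gradients.training.datasets.data_formats.default_formats import get_default_data_format

# The constant is just the registry key, so these two lookups are equivalent.
fmt = get_default_data_format(ConcatenatedTensorFormats.LABEL_CXCYWH)
fmt = get_default_data_format("LABEL_CXCYWH")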
src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml

@@ -35,9 +35,8 @@ train_dataset_params:
         input_dim: ${dataset_params.train_dataset_params.input_dim}
         max_targets: 120
     - DetectionTargetsFormatTransform:
-        output_format:
-          _target_: super_gradients.training.utils.detection_utils.DetectionTargetsFormat # targets format
-          value: LABEL_CXCYWH
+        image_shape: ${dataset_params.train_dataset_params.input_dim}
+        output_format: LABEL_CXCYWH
   tight_box_rotation: False
   class_inclusion_list:
   max_num_samples:
@@ -70,10 +69,9 @@ val_dataset_params:
   - DetectionPaddedRescale:
       input_dim: ${dataset_params.val_dataset_params.input_dim}
   - DetectionTargetsFormatTransform:
+      image_shape: ${dataset_params.val_dataset_params.input_dim}
       max_targets: 50
-      output_format:
-        _target_: super_gradients.training.utils.detection_utils.DetectionTargetsFormat # targets format
-        value: LABEL_CXCYWH
+      output_format: LABEL_CXCYWH
   tight_box_rotation: False
   class_inclusion_list:
   max_num_samples:
src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml

@@ -30,9 +30,9 @@ train_dataset_params:
         input_dim: ${dataset_params.train_dataset_params.input_dim}
         max_targets: 120
     - DetectionTargetsFormatTransform:
-        output_format:
-          _target_: super_gradients.training.utils.detection_utils.DetectionTargetsFormat
-          value: LABEL_NORMALIZED_CXCYWH
+        image_shape: ${dataset_params.train_dataset_params.input_dim}
+        max_targets: 50
+        output_format: LABEL_NORMALIZED_CXCYWH

   tight_box_rotation: False
   class_inclusion_list:
@@ -65,10 +65,9 @@ val_dataset_params:
     - DetectionPaddedRescale:
         input_dim: ${dataset_params.val_dataset_params.input_dim}
     - DetectionTargetsFormatTransform:
+        image_shape: ${dataset_params.val_dataset_params.input_dim}
         max_targets: 50
-        output_format:
-          _target_: super_gradients.training.utils.detection_utils.DetectionTargetsFormat
-          value: LABEL_NORMALIZED_CXCYWH
+        output_format: LABEL_NORMALIZED_CXCYWH
   tight_box_rotation: False
   class_inclusion_list:
   max_num_samples:
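The recipe entries above are equivalent to constructing the transform directly; the output_format string is resolved by ConcatenatedTensorFormatFactory via the @resolve_param decorators added in transforms.py further down. A hedged sketch of the same configuration in Python, with (640, 640) standing in for the interpolated input_dim:

from super_gradients.training.transforms.transforms import DetectionTargetsFormatTransform

# (640, 640) is only an illustrative value for
# ${dataset_params.val_dataset_params.input_dim}.
transform = DetectionTargetsFormatTransform(
    image_shape=(640, 640),
    max_targets=50,
    output_format="LABEL_CXCYWH",  # string resolved to the LABEL_CXCYWH format object
)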
src/super_gradients/training/datasets/data_formats/__init__.py (new file)

from .format_converter import ConcatenatedTensorFormatConverter
from .output_adapters import DetectionOutputAdapter
from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem

__all__ = ["ConcatenatedTensorFormatConverter", "DetectionOutputAdapter", "TensorSliceItem", "ConcatenatedTensorFormat", "BoundingBoxesTensorSliceItem"]
src/super_gradients/training/datasets/data_formats/bbox_formats/cxcywh.py

@@ -4,7 +4,7 @@ from typing import Tuple
 import numpy as np
 import torch

-from super_gradients.training.utils.bbox_formats.bbox_format import (
+from super_gradients.training.datasets.data_formats.bbox_formats.bbox_format import (
     BoundingBoxFormat,
 )

src/super_gradients/training/datasets/data_formats/bbox_formats/normalized_cxcywh.py

@@ -1,10 +1,10 @@
 from typing import Tuple

-from super_gradients.training.utils.bbox_formats.bbox_format import (
+from super_gradients.training.datasets.data_formats.bbox_formats.bbox_format import (
     BoundingBoxFormat,
 )
-from super_gradients.training.utils.bbox_formats.cxcywh import cxcywh_to_xyxy, xyxy_to_cxcywh_inplace, cxcywh_to_xyxy_inplace
-from super_gradients.training.utils.bbox_formats.normalized_xyxy import (
+from super_gradients.training.datasets.data_formats.bbox_formats.cxcywh import cxcywh_to_xyxy, xyxy_to_cxcywh_inplace, cxcywh_to_xyxy_inplace
+from super_gradients.training.datasets.data_formats.bbox_formats.normalized_xyxy import (
     xyxy_to_normalized_xyxy_inplace,
     xyxy_to_normalized_xyxy,
     normalized_xyxy_to_xyxy_inplace,
src/super_gradients/training/datasets/data_formats/bbox_formats/normalized_xywh.py

@@ -1,14 +1,14 @@
 from typing import Tuple

-from super_gradients.training.utils.bbox_formats.bbox_format import (
+from super_gradients.training.datasets.data_formats.bbox_formats.bbox_format import (
     BoundingBoxFormat,
 )
-from super_gradients.training.utils.bbox_formats.normalized_xyxy import (
+from super_gradients.training.datasets.data_formats.bbox_formats.normalized_xyxy import (
     normalized_xyxy_to_xyxy_inplace,
     xyxy_to_normalized_xyxy_inplace,
     xyxy_to_normalized_xyxy,
 )
-from super_gradients.training.utils.bbox_formats.xywh import xywh_to_xyxy_inplace, xywh_to_xyxy, xyxy_to_xywh_inplace
+from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy_inplace, xywh_to_xyxy, xyxy_to_xywh_inplace

 __all__ = [
     "xyxy_to_normalized_xywh",
src/super_gradients/training/datasets/data_formats/bbox_formats/normalized_xyxy.py

@@ -5,7 +5,7 @@ import numpy as np
 import torch
 from torch import Tensor

-from super_gradients.training.utils.bbox_formats.bbox_format import (
+from super_gradients.training.datasets.data_formats.bbox_formats.bbox_format import (
     BoundingBoxFormat,
 )

src/super_gradients/training/datasets/data_formats/bbox_formats/xywh.py

@@ -3,7 +3,7 @@ from typing import Tuple
 import numpy as np
 import torch

-from super_gradients.training.utils.bbox_formats.bbox_format import (
+from super_gradients.training.datasets.data_formats.bbox_formats.bbox_format import (
     BoundingBoxFormat,
 )

src/super_gradients/training/datasets/data_formats/bbox_formats/xyxy.py

@@ -1,6 +1,6 @@
 from typing import Tuple

-from super_gradients.training.utils.bbox_formats.bbox_format import (
+from super_gradients.training.datasets.data_formats.bbox_formats.bbox_format import (
     BoundingBoxFormat,
 )

src/super_gradients/training/datasets/data_formats/bbox_formats/yxyx.py

@@ -3,7 +3,7 @@ from typing import Tuple
 import numpy as np
 import torch

-from super_gradients.training.utils.bbox_formats.bbox_format import BoundingBoxFormat
+from super_gradients.training.datasets.data_formats.bbox_formats.bbox_format import BoundingBoxFormat

 __all__ = ["YXYXCoordinateFormat", "xyxy_to_yxyx", "xyxy_to_yxyx_inplace"]

src/super_gradients/training/datasets/data_formats/default_formats.py (new file)

from super_gradients.common.object_names import ConcatenatedTensorFormats
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
from super_gradients.training.datasets.data_formats.bbox_formats import (
    XYXYCoordinateFormat,
    XYWHCoordinateFormat,
    CXCYWHCoordinateFormat,
    NormalizedXYXYCoordinateFormat,
    NormalizedXYWHCoordinateFormat,
    NormalizedCXCYWHCoordinateFormat,
)

XYXY_LABEL = ConcatenatedTensorFormat(
    layout=(
        BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
        TensorSliceItem(length=1, name="labels"),
    )
)

XYWH_LABEL = ConcatenatedTensorFormat(
    layout=(
        BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()),
        TensorSliceItem(length=1, name="labels"),
    )
)

CXCYWH_LABEL = ConcatenatedTensorFormat(
    layout=(
        BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
        TensorSliceItem(length=1, name="labels"),
    )
)

LABEL_XYXY = ConcatenatedTensorFormat(
    layout=(
        TensorSliceItem(length=1, name="labels"),
        BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
    )
)

LABEL_XYWH = ConcatenatedTensorFormat(
    layout=(
        TensorSliceItem(length=1, name="labels"),
        BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()),
    )
)

LABEL_CXCYWH = ConcatenatedTensorFormat(
    layout=(
        TensorSliceItem(length=1, name="labels"),
        BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
    )
)

NORMALIZED_XYXY_LABEL = ConcatenatedTensorFormat(
    layout=(
        BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()),
        TensorSliceItem(length=1, name="labels"),
    )
)

NORMALIZED_XYWH_LABEL = ConcatenatedTensorFormat(
    layout=(
        BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
        TensorSliceItem(length=1, name="labels"),
    )
)

NORMALIZED_CXCYWH_LABEL = ConcatenatedTensorFormat(
    layout=(
        BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
        TensorSliceItem(length=1, name="labels"),
    )
)

LABEL_NORMALIZED_XYXY = ConcatenatedTensorFormat(
    layout=(
        TensorSliceItem(length=1, name="labels"),
        BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()),
    )
)

LABEL_NORMALIZED_XYWH = ConcatenatedTensorFormat(
    layout=(
        TensorSliceItem(length=1, name="labels"),
        BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
    )
)

LABEL_NORMALIZED_CXCYWH = ConcatenatedTensorFormat(
    layout=(
        TensorSliceItem(length=1, name="labels"),
        BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
    )
)


def get_default_data_format(format_name: str) -> ConcatenatedTensorFormat:
    return DEFAULT_CONCATENATED_TENSOR_FORMATS[format_name]


DEFAULT_CONCATENATED_TENSOR_FORMATS = {
    ConcatenatedTensorFormats.XYXY_LABEL: XYXY_LABEL,
    ConcatenatedTensorFormats.XYWH_LABEL: XYWH_LABEL,
    ConcatenatedTensorFormats.CXCYWH_LABEL: CXCYWH_LABEL,
    ConcatenatedTensorFormats.LABEL_XYXY: LABEL_XYXY,
    ConcatenatedTensorFormats.LABEL_XYWH: LABEL_XYWH,
    ConcatenatedTensorFormats.LABEL_CXCYWH: LABEL_CXCYWH,
    ConcatenatedTensorFormats.NORMALIZED_XYXY_LABEL: NORMALIZED_XYXY_LABEL,
    ConcatenatedTensorFormats.NORMALIZED_XYWH_LABEL: NORMALIZED_XYWH_LABEL,
    ConcatenatedTensorFormats.NORMALIZED_CXCYWH_LABEL: NORMALIZED_CXCYWH_LABEL,
    ConcatenatedTensorFormats.LABEL_NORMALIZED_XYXY: LABEL_NORMALIZED_XYXY,
    ConcatenatedTensorFormats.LABEL_NORMALIZED_XYWH: LABEL_NORMALIZED_XYWH,
    ConcatenatedTensorFormats.LABEL_NORMALIZED_CXCYWH: LABEL_NORMALIZED_CXCYWH,
}
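Each default format bundles a bbox coordinate format with a one-column label slice, and the resulting object is introspectable. A short sketch using attributes defined in formats.py:

from super_gradients.training.datasets.data_formats.default_formats import get_default_data_format

fmt = get_default_data_format("LABEL_NORMALIZED_CXCYWH")
print(fmt.num_channels)                     # 5: one label column plus four bbox columns
print(fmt.locations["bboxes"])              # (1, 5): bboxes occupy columns 1..4
print(fmt.bboxes_format.format.normalized)  # True for the NORMALIZED_* coordinate formats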
src/super_gradients/training/datasets/data_formats/format_converter.py (new file)

from typing import Tuple, Union

import numpy as np
from torch import Tensor

from super_gradients.training.datasets.data_formats.bbox_formats import convert_bboxes
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, apply_on_bboxes, get_permutation_indexes

__all__ = ["ConcatenatedTensorFormatConverter"]


class ConcatenatedTensorFormatConverter:
    def __init__(
        self,
        input_format: ConcatenatedTensorFormat,
        output_format: ConcatenatedTensorFormat,
        image_shape: Union[Tuple[int, int], None],
    ):
        """
        Converts concatenated tensors from input format to output format.
        Example:
            >>> from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormatConverter
            >>> from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH, LABEL_NORMALIZED_XYXY
            >>> h, w = 100, 200
            >>> input_target = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
            >>> expected_output_target = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
            >>>
            >>> transform = ConcatenatedTensorFormatConverter(input_format=LABEL_NORMALIZED_XYXY, output_format=LABEL_CXCYWH, image_shape=(h, w))
            >>>
            >>> # np.float32 approximation of multiplication/division can lead to uncertainty of up to 1e-7 in precision
            >>> assert np.allclose(transform(input_target), expected_output_target, atol=1e-6)

        :param input_format:  Format definition of the inputs
        :param output_format: Format definition of the outputs
        :param image_shape:   Shape of the input image (rows, cols), used for converting bbox coordinates from/to normalized format.
                              If you're not using normalized coordinates you can set this to None
        """
        self.permutation_indexes = get_permutation_indexes(input_format, output_format)

        self.input_format = input_format
        self.output_format = output_format
        self.image_shape = image_shape
        self.input_length = input_format.num_channels

    def __call__(self, tensor: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]:
        if tensor.shape[-1] != self.input_length:
            raise RuntimeError(
                f"Number of channels in last dimension of input tensor ({tensor.shape[-1]}) must be "
                f"equal to {self.input_length} as defined by input format."
            )
        tensor = tensor[:, self.permutation_indexes]
        tensor = apply_on_bboxes(fn=self._convert_bbox, tensor=tensor, tensor_format=self.output_format)
        return tensor

    def _convert_bbox(self, bboxes: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]:
        return convert_bboxes(
            bboxes=bboxes,
            source_format=self.input_format.bboxes_format.format,
            target_format=self.output_format.bboxes_format.format,
            inplace=False,
            image_shape=self.image_shape,
        )
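Beyond the doctest above, the converter also handles layouts with extra named fields, since get_permutation_indexes reorders every slice by name. A sketch with a hypothetical scores slice (the two layouts below are illustrative, not part of the PR's defaults):

import numpy as np

from super_gradients.training.datasets.data_formats import (
    ConcatenatedTensorFormat,
    BoundingBoxesTensorSliceItem,
    TensorSliceItem,
    ConcatenatedTensorFormatConverter,
)
from super_gradients.training.datasets.data_formats.bbox_formats import XYXYCoordinateFormat, CXCYWHCoordinateFormat

# Hypothetical 6-channel layouts: 4 bbox coordinates + score + label.
XYXY_SCORE_LABEL = ConcatenatedTensorFormat(
    layout=(
        BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
        TensorSliceItem(length=1, name="scores"),
        TensorSliceItem(length=1, name="labels"),
    )
)
LABEL_SCORE_CXCYWH = ConcatenatedTensorFormat(
    layout=(
        TensorSliceItem(length=1, name="labels"),
        TensorSliceItem(length=1, name="scores"),
        BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
    )
)

# image_shape=None is fine here: neither format is normalized.
converter = ConcatenatedTensorFormatConverter(input_format=XYXY_SCORE_LABEL, output_format=LABEL_SCORE_CXCYWH, image_shape=None)
targets = np.array([[10, 20, 30, 40, 0.9, 7]], dtype=np.float32)
print(converter(targets))  # [[7, 0.9, 20, 30, 20, 20]]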
src/super_gradients/training/datasets/data_formats/formats.py (new file)

import collections
from typing import Tuple, Union, List, Mapping, Callable

import numpy as np
from torch import Tensor

from super_gradients.training.datasets.data_formats.bbox_formats import BoundingBoxFormat


class DetectionOutputFormat:
    pass


class TensorSliceItem:
    length: int
    name: str

    def __init__(self, name: str, length: int):
        self.name = name
        self.length = length

    def __repr__(self):
        return f"name={self.name} length={self.length}"


class BoundingBoxesTensorSliceItem(TensorSliceItem):
    format: BoundingBoxFormat

    def __init__(self, name: str, format: BoundingBoxFormat):
        super().__init__(name, length=format.get_num_parameters())
        self.format = format

    def __repr__(self):
        return f"name={self.name} length={self.length} format={self.format}"


class ConcatenatedTensorFormat(DetectionOutputFormat):
    """
    Define the output format that return a single tensor of shape [N,M] (N - number of detections,
    M - sum of bbox attributes) that is a concatenated from bbox coordinates and other fields.
    A layout defines the order of concatenated tensors. For instance:
    - layout: (bboxes, scores, labels) gives a Tensor that is product of torch.cat([bboxes, scores, labels], dim=1)
    - layout: (labels, bboxes) produce a Tensor from torch.cat([labels, bboxes], dim=1)
    """

    layout: Mapping[str, TensorSliceItem]
    locations: Mapping[str, Tuple[int, int]]
    indexes: Mapping[str, List[int]]
    num_channels: int

    @property
    def bboxes_format(self) -> BoundingBoxesTensorSliceItem:
        bbox_items = [x for x in self.layout.values() if isinstance(x, BoundingBoxesTensorSliceItem)]
        return bbox_items[0]

    def __init__(self, layout: Union[List[TensorSliceItem], Tuple[TensorSliceItem, ...]]):
        bbox_items = [x for x in layout if isinstance(x, BoundingBoxesTensorSliceItem)]
        if len(bbox_items) != 1:
            raise RuntimeError("Number of bounding box items must be strictly equal to 1")

        _layout = []
        _locations = []
        _indexes = []

        offset = 0
        for item in layout:
            location_indexes = list(range(offset, offset + item.length))
            location_slice = offset, offset + item.length

            _layout.append((item.name, item))
            _locations.append((item.name, location_slice))
            _indexes.append((item.name, location_indexes))
            offset += item.length

        self.layout = collections.OrderedDict(_layout)
        self.locations = collections.OrderedDict(_locations)
        self.indexes = collections.OrderedDict(_indexes)
        self.num_channels = offset

    def __repr__(self):
        return str(self.layout)


def apply_on_bboxes(
    fn: Callable[[Union[np.ndarray, Tensor]], Union[np.ndarray, Tensor]],
    tensor: Union[np.ndarray, Tensor],
    tensor_format: ConcatenatedTensorFormat,
) -> Union[np.ndarray, Tensor]:
    """Apply inplace a function only on the bboxes of a concatenated tensor.

    :param fn:            Function to apply on the bboxes.
    :param tensor:        Concatenated tensor that include - among other - the bboxes.
    :param tensor_format: Format of the tensor, required to know the indexes of the bboxes.
    :return:              Tensor, after applying INPLACE the fn on the bboxes
    """
    return apply_on_layout(fn=fn, tensor=tensor, tensor_format=tensor_format, layout_name=tensor_format.bboxes_format.name)


def apply_on_layout(
    fn: Callable[[Union[np.ndarray, Tensor]], Union[np.ndarray, Tensor]],
    tensor: Union[np.ndarray, Tensor],
    tensor_format: ConcatenatedTensorFormat,
    layout_name: str,
) -> Union[np.ndarray, Tensor]:
    """Apply inplace a function only on a specific layout of a concatenated tensor.

    :param fn:            Function to apply on the bboxes.
    :param tensor:        Concatenated tensor that include - among other - the layout of interest.
    :param tensor_format: Format of the tensor, required to know the indexes of the layout.
    :param layout_name:   Name of the layout of interest. It has to be defined in the tensor_format.
    :return:              Tensor, after applying INPLACE the fn on the layout
    """
    location = slice(*iter(tensor_format.locations[layout_name]))
    result = fn(tensor[..., location])
    tensor[..., location] = result
    return tensor


def filter_on_bboxes(
    fn: Callable[[Union[np.ndarray, Tensor]], Union[np.ndarray, Tensor]],
    tensor: Union[np.ndarray, Tensor],
    tensor_format: ConcatenatedTensorFormat,
) -> Union[np.ndarray, Tensor]:
    """Filter the tensor according to a condition on the bboxes.

    :param fn:            Function to filter the bboxes (keep only True elements).
    :param tensor:        Concatenated tensor that include - among other - the bboxes.
    :param tensor_format: Format of the tensor, required to know the indexes of the bboxes.
    :return:              Tensor, after applying INPLACE the fn on the bboxes
    """
    return filter_on_layout(fn=fn, tensor=tensor, tensor_format=tensor_format, layout_name=tensor_format.bboxes_format.name)


def filter_on_layout(
    fn: Callable[[Union[np.ndarray, Tensor]], Union[np.ndarray, Tensor]],
    tensor: Union[np.ndarray, Tensor],
    tensor_format: ConcatenatedTensorFormat,
    layout_name: str,
) -> Union[np.ndarray, Tensor]:
    """Filter the tensor according to a condition on a specific layout.

    :param fn:            Function to filter the bboxes (keep only True elements).
    :param tensor:        Concatenated tensor that include - among other - the layout of interest.
    :param tensor_format: Format of the tensor, required to know the indexes of the layout.
    :param layout_name:   Name of the layout of interest. It has to be defined in the tensor_format.
    :return:              Tensor, after filtering the bboxes according to fn.
    """
    location = slice(*tensor_format.locations[layout_name])
    mask = fn(tensor[..., location])
    tensor = tensor[mask]
    return tensor


def get_permutation_indexes(input_format: ConcatenatedTensorFormat, output_format: ConcatenatedTensorFormat) -> List[int]:
    """Compute the permutations required to change the format layout order.

    :param input_format:  Input format to transform from
    :param output_format: Output format to transform to
    :return:              Permutation indexes to go from input to output format.
    """
    output_indexes = []
    for output_name, output_spec in output_format.layout.items():
        if output_name not in input_format.layout:
            raise KeyError(f"Requested item '{output_name}' was not found among input format spec. Present items are: {tuple(input_format.layout.keys())}")

        input_spec = input_format.layout[output_name]
        if input_spec.length != output_spec.length:
            raise RuntimeError(
                f"Length of the output must match in input and output format. "
                f"Input spec size is {input_spec.length} for key '{output_name}' and output spec size is {output_spec.length}."
            )
        indexes = input_format.indexes[output_name]
        output_indexes.extend(indexes)
    return output_indexes
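To make the helper semantics concrete, a short sketch exercising filter_on_bboxes and get_permutation_indexes on the default formats (values are illustrative):

import numpy as np

from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_XYXY
from super_gradients.training.datasets.data_formats.formats import filter_on_bboxes, get_permutation_indexes

targets = np.array(
    [
        [10, 10, 50, 50, 1],  # 40x40 box, kept by the filter below
        [10, 10, 12, 12, 2],  # 2x2 box, dropped
    ],
    dtype=np.float32,
)

# Keep only rows whose bbox is wider and taller than 5 pixels.
# The callback receives just the bbox columns; for XYXY that means
# width = x2 - x1 and height = y2 - y1.
kept = filter_on_bboxes(
    fn=lambda b: np.minimum(b[:, 2] - b[:, 0], b[:, 3] - b[:, 1]) > 5,
    tensor=targets,
    tensor_format=XYXY_LABEL,
)
print(kept)  # only the first row survives

# Column permutation that moves the trailing label in front of the bbox.
print(get_permutation_indexes(XYXY_LABEL, LABEL_XYXY))  # [4, 0, 1, 2, 3]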
src/super_gradients/training/datasets/data_formats/output_adapters/__init__.py (new file)

from .detection_adapter import DetectionOutputAdapter

__all__ = ["DetectionOutputAdapter"]
src/super_gradients/training/datasets/data_formats/output_adapters/detection_adapter.py

@@ -4,8 +4,8 @@ from typing import Tuple, Union, Callable
 import torch
 from torch import nn, Tensor

-from super_gradients.training.utils.bbox_formats import BoundingBoxFormat
-from super_gradients.training.utils.output_adapters.formats import ConcatenatedTensorFormat
+from super_gradients.training.datasets.data_formats.bbox_formats import BoundingBoxFormat
+from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat

 __all__ = ["DetectionOutputAdapter"]

@@ -72,8 +72,8 @@ class DetectionOutputAdapter(nn.Module):
     Adapter class for converting model's predictions for object detection to a desired format.
     This adapter supports torch.jit tracing & scripting & onnx conversion.

-    >>> from super_gradients.training.utils.output_adapters.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
-    >>> from super_gradients.training.utils.bbox_formats import XYXYCoordinateFormat, NormalizedXYWHCoordinateFormat
+    >>> from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
+    >>> from super_gradients.training.datasets.data_formats.bbox_formats import XYXYCoordinateFormat, NormalizedXYWHCoordinateFormat
     >>>
     >>> class CustomDetectionHead(nn.Module):
     >>>    num_classes: int = 123
src/super_gradients/training/transforms/all_transforms.py

@@ -22,7 +22,6 @@ from super_gradients.training.transforms.transforms import (
     DetectionMixup,
     DetectionHSV,
     DetectionHorizontalFlip,
-    DetectionTargetsFormat,
     DetectionPaddedRescale,
     DetectionTargetsFormatTransform,
     Standardize,
@@ -82,7 +81,6 @@ TRANSFORMS = {
     Transforms.DetectionHSV: DetectionHSV,
     Transforms.DetectionHorizontalFlip: DetectionHorizontalFlip,
     Transforms.DetectionPaddedRescale: DetectionPaddedRescale,
-    Transforms.DetectionTargetsFormat: DetectionTargetsFormat,
     Transforms.DetectionTargetsFormatTransform: DetectionTargetsFormatTransform,
     Transforms.RandomResizedCropAndInterpolation: RandomResizedCropAndInterpolation,
     Transforms.RandAugmentTransform: rand_augment_transform,
src/super_gradients/training/transforms/transforms.py

@@ -9,7 +9,12 @@ from torchvision import transforms as transforms
 import numpy as np
 import cv2
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.training.utils.detection_utils import get_mosaic_coordinate, adjust_box_anns, xyxy2cxcywh, cxcywh2xyxy, DetectionTargetsFormat
+from super_gradients.common.decorators.factory_decorator import resolve_param
+from super_gradients.common.factories.data_formats_factory import ConcatenatedTensorFormatFactory
+from super_gradients.training.utils.detection_utils import get_mosaic_coordinate, adjust_box_anns, xyxy2cxcywh, cxcywh2xyxy
+from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormatConverter
+from super_gradients.training.datasets.data_formats.formats import filter_on_bboxes, ConcatenatedTensorFormat
+from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_CXCYWH

 image_resample = Image.BILINEAR
 mask_resample = Image.NEAREST
@@ -757,88 +762,60 @@ class DetectionTargetsFormatTransform(DetectionTransform):
     """
     Detection targets format transform

-    Converts targets in input_format to output_format.
+    Convert targets in input_format to output_format, filter small bboxes and pad targets.
     Attributes:
-        input_format: DetectionTargetsFormat: input target format
-        output_format: DetectionTargetsFormat: output target format
-        min_bbox_edge_size: int: bboxes with edge size lower then this values will be removed.
-        max_targets: int: max objects in single image, padding target to this size.
+        image_shape:        Shape of the images to transform.
+        input_format:       Format of the input targets. For instance [xmin, ymin, xmax, ymax, cls_id] refers to XYXY_LABEL
+        output_format:      Format of the output targets. For instance [xmin, ymin, xmax, ymax, cls_id] refers to XYXY_LABEL
+        min_bbox_edge_size: bboxes with edge size lower then this values will be removed.
+        max_targets:        Max objects in single image, padding target to this size.
     """

+    @resolve_param("input_format", ConcatenatedTensorFormatFactory())
+    @resolve_param("output_format", ConcatenatedTensorFormatFactory())
     def __init__(
         self,
-        input_format: DetectionTargetsFormat = DetectionTargetsFormat.XYXY_LABEL,
-        output_format: DetectionTargetsFormat = DetectionTargetsFormat.LABEL_CXCYWH,
+        image_shape: tuple,
+        input_format: ConcatenatedTensorFormat = XYXY_LABEL,
+        output_format: ConcatenatedTensorFormat = LABEL_CXCYWH,
         min_bbox_edge_size: float = 1,
         max_targets: int = 120,
     ):
         super(DetectionTargetsFormatTransform, self).__init__()
         self.input_format = input_format
         self.output_format = output_format
-        self.min_bbox_edge_size = min_bbox_edge_size
         self.max_targets = max_targets
+        self.min_bbox_edge_size = min_bbox_edge_size / max(image_shape) if output_format.bboxes_format.format.normalized else min_bbox_edge_size
+        self.targets_format_converter = ConcatenatedTensorFormatConverter(input_format=input_format, output_format=output_format, image_shape=image_shape)

-    def __call__(self, sample):
-        normalized_input = "NORMALIZED" in self.input_format.value
-        normalized_output = "NORMALIZED" in self.output_format.value
-        normalize = not normalized_input and normalized_output
-        denormalize = normalized_input and not normalized_output
-
-        label_first_in_input = self.input_format.value.split("_")[0] == "LABEL"
-        label_first_in_output = self.output_format.value.split("_")[0] == "LABEL"
-
-        input_xyxy_format = "XYXY" in self.input_format.value
-        output_xyxy_format = "XYXY" in self.output_format.value
-        convert2xyxy = not input_xyxy_format and output_xyxy_format
-        convert2cxcy = input_xyxy_format and not output_xyxy_format
-
-        image, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target")
-
-        _, h, w = image.shape
-
-        def _format_target(targets_in):
-            if label_first_in_input:
-                labels, boxes = targets_in[:, 0], targets_in[:, 1:]
-            else:
-                boxes, labels = targets_in[:, :4], targets_in[:, 4]
-
-            if convert2cxcy:
-                boxes = xyxy2cxcywh(boxes)
-            elif convert2xyxy:
-                boxes = cxcywh2xyxy(boxes)
-
-            if normalize:
-                boxes[:, 0] = boxes[:, 0] / w
-                boxes[:, 1] = boxes[:, 1] / h
-                boxes[:, 2] = boxes[:, 2] / w
-                boxes[:, 3] = boxes[:, 3] / h
-
-            elif denormalize:
-                boxes[:, 0] = boxes[:, 0] * w
-                boxes[:, 1] = boxes[:, 1] * h
-                boxes[:, 2] = boxes[:, 2] * w
-                boxes[:, 3] = boxes[:, 3] * h
-
-            min_bbox_edge_size = self.min_bbox_edge_size / max(w, h) if normalized_output else self.min_bbox_edge_size
+    def __call__(self, sample: dict) -> dict:
+        sample["target"] = self.apply_on_targets(sample["target"])
+        if "crowd_target" in sample.keys():
+            sample["crowd_target"] = self.apply_on_targets(sample["crowd_target"])
+        return sample

-            cxcywh_boxes = boxes if not output_xyxy_format else xyxy2cxcywh(boxes.copy())
+    def apply_on_targets(self, targets: np.ndarray) -> np.ndarray:
+        """Convert targets in input_format to output_format, filter small bboxes and pad targets"""
+        targets = self.targets_format_converter(targets)
+        targets = self.filter_small_bboxes(targets)
+        targets = self.pad_targets(targets)
+        return targets

-            mask_b = np.minimum(cxcywh_boxes[:, 2], cxcywh_boxes[:, 3]) > min_bbox_edge_size
-            boxes_t = boxes[mask_b]
-            labels_t = labels[mask_b]
+    def filter_small_bboxes(self, targets: np.ndarray) -> np.ndarray:
+        """Filter bboxes smaller than specified threshold."""

-            labels_t = np.expand_dims(labels_t, 1)
-            targets_t = np.hstack((labels_t, boxes_t)) if label_first_in_output else np.hstack((boxes_t, labels_t))
-            padded_targets = np.zeros((self.max_targets, 5))
-            padded_targets[range(len(targets_t))[: self.max_targets]] = targets_t[: self.max_targets]
-            padded_targets = np.ascontiguousarray(padded_targets, dtype=np.float32)
+        def _is_big_enough(bboxes: np.ndarray) -> np.ndarray:
+            return np.minimum(bboxes[:, 2], bboxes[:, 3]) > self.min_bbox_edge_size

-            return padded_targets
+        targets = filter_on_bboxes(fn=_is_big_enough, tensor=targets, tensor_format=self.output_format)
+        return targets

-        sample["target"] = _format_target(targets)
-        if crowd_targets is not None:
-            sample["crowd_target"] = _format_target(crowd_targets)
-        return sample
+    def pad_targets(self, targets: np.ndarray) -> np.ndarray:
+        """Pad targets."""
+        padded_targets = np.zeros((self.max_targets, targets.shape[-1]))
+        padded_targets[range(len(targets))[: self.max_targets]] = targets[: self.max_targets]
+        padded_targets = np.ascontiguousarray(padded_targets, dtype=np.float32)
+        return padded_targets


 def get_aug_params(value: Union[tuple, float], center: float = 0):
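Taken together, the rewritten transform now delegates format conversion to ConcatenatedTensorFormatConverter and then filters and pads. A brief usage sketch consistent with the updated unit tests below:

import numpy as np

from super_gradients.training.transforms.transforms import DetectionTargetsFormatTransform
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_CXCYWH

image = np.zeros((3, 100, 200), dtype=np.uint8)  # illustrative CHW image
sample = {"image": image, "target": np.array([[10, 20, 30, 40, 7]], dtype=np.float32)}

transform = DetectionTargetsFormatTransform(
    image_shape=image.shape[1:],  # (h, w), as in the tests
    max_targets=1,
    input_format=XYXY_LABEL,
    output_format=LABEL_CXCYWH,
)
print(transform(sample)["target"])  # [[7, 20, 30, 20, 20]]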
src/super_gradients/training/utils/output_adapters/__init__.py (deleted file)

from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
from .detection_adapter import DetectionOutputAdapter

__all__ = ["DetectionOutputAdapter", "TensorSliceItem", "ConcatenatedTensorFormat", "BoundingBoxesTensorSliceItem"]
src/super_gradients/training/utils/output_adapters/formats.py (deleted file; its classes move to datasets/data_formats/formats.py above)

import collections
from typing import Tuple, Union, List, Mapping

from super_gradients.training.utils.bbox_formats import BoundingBoxFormat


class DetectionOutputFormat:
    pass


class TensorSliceItem:
    length: int
    name: str

    def __init__(self, name: str, length: int):
        self.name = name
        self.length = length

    def __repr__(self):
        return f"name={self.name} length={self.length}"


class BoundingBoxesTensorSliceItem(TensorSliceItem):
    format: BoundingBoxFormat

    def __init__(self, name: str, format: BoundingBoxFormat):
        super().__init__(name, length=format.get_num_parameters())
        self.format = format

    def __repr__(self):
        return f"name={self.name} length={self.length} format={self.format}"


class ConcatenatedTensorFormat(DetectionOutputFormat):
    """
    Define the output format that return a single tensor of shape [N,M] (N - number of detections,
    M - sum of bbox attributes) that is a concatenated from bbox coordinates and other fields.
    A layout defines the order of concatenated tensors. For instance:
    - layout: (bboxes, scores, labels) gives a Tensor that is product of torch.cat([bboxes, scores, labels], dim=1)
    - layout: (labels, bboxes) produce a Tensor from torch.cat([labels, bboxes], dim=1)
    """

    layout: Mapping[str, TensorSliceItem]
    locations: Mapping[str, Tuple[int, int]]
    indexes: Mapping[str, List[int]]
    num_channels: int

    @property
    def bboxes_format(self) -> BoundingBoxesTensorSliceItem:
        bbox_items = [x for x in self.layout.values() if isinstance(x, BoundingBoxesTensorSliceItem)]
        return bbox_items[0]

    def __init__(self, layout: Union[List[TensorSliceItem], Tuple[TensorSliceItem, ...]]):
        bbox_items = [x for x in layout if isinstance(x, BoundingBoxesTensorSliceItem)]
        if len(bbox_items) != 1:
            raise RuntimeError("Number of bounding box items must be strictly equal to 1")

        _layout = []
        _locations = []
        _indexes = []

        offset = 0
        for item in layout:
            location_indexes = list(range(offset, offset + item.length))
            location_slice = offset, offset + item.length

            _layout.append((item.name, item))
            _locations.append((item.name, location_slice))
            _indexes.append((item.name, location_indexes))
            offset += item.length

        self.layout = collections.OrderedDict(_layout)
        self.locations = collections.OrderedDict(_locations)
        self.indexes = collections.OrderedDict(_indexes)
        self.num_channels = offset

    def __repr__(self):
        return str(self.layout)
tests/unit_tests/bbox_formats_test.py

@@ -6,8 +6,9 @@ import unittest
 import numpy as np
 import torch

+
 from super_gradients.common.factories.bbox_format_factory import BBoxFormatFactory
-from super_gradients.training.utils.bbox_formats import (
+from super_gradients.training.datasets.data_formats.bbox_formats import (
     CXCYWHCoordinateFormat,
     NormalizedXYXYCoordinateFormat,
     NormalizedXYWHCoordinateFormat,
@@ -19,21 +20,21 @@ from super_gradients.training.utils.bbox_formats import (
     BBOX_FORMATS,
     BoundingBoxFormat,
 )
-from super_gradients.training.utils.bbox_formats.normalized_cxcywh import (
+from super_gradients.training.datasets.data_formats.bbox_formats.normalized_cxcywh import (
     normalized_cxcywh_to_xyxy_inplace,
     xyxy_to_normalized_cxcywh_inplace,
     xyxy_to_normalized_cxcywh,
     normalized_cxcywh_to_xyxy,
 )
-from super_gradients.training.utils.bbox_formats.normalized_xywh import (
+from super_gradients.training.datasets.data_formats.bbox_formats.normalized_xywh import (
     xyxy_to_normalized_xywh_inplace,
     xyxy_to_normalized_xywh,
     normalized_xywh_to_xyxy_inplace,
     normalized_xywh_to_xyxy,
 )
-from super_gradients.training.utils.bbox_formats.xywh import xyxy_to_xywh, xywh_to_xyxy, xywh_to_xyxy_inplace, xyxy_to_xywh_inplace
-from super_gradients.training.utils.bbox_formats.yxyx import xyxy_to_yxyx, xyxy_to_yxyx_inplace
-from super_gradients.training.utils.output_adapters.detection_adapter import ConvertBoundingBoxes
+from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xyxy_to_xywh, xywh_to_xyxy, xywh_to_xyxy_inplace, xyxy_to_xywh_inplace
+from super_gradients.training.datasets.data_formats.bbox_formats.yxyx import xyxy_to_yxyx, xyxy_to_yxyx_inplace
+from super_gradients.training.datasets.data_formats.output_adapters.detection_adapter import ConvertBoundingBoxes


 class BBoxFormatsTest(unittest.TestCase):
@@ -255,7 +256,7 @@ class BBoxFormatsTest(unittest.TestCase):
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     adapter_fname = os.path.join(tmpdirname, "adapter.onnx")
                     # Just test that export works, we test the correctness in the detection_output_adapter_test.py
-                    torch.onnx.export(module, gt_bboxes.clone(), adapter_fname)
+                    torch.onnx.export(module, gt_bboxes.clone(), adapter_fname, opset_version=11)


 if __name__ == "__main__":
tests/unit_tests/detection_output_adapter_test.py

@@ -7,8 +7,13 @@ import onnx
 import onnxruntime as ort
 import torch.jit

-from super_gradients.training.utils.bbox_formats import NormalizedXYWHCoordinateFormat, CXCYWHCoordinateFormat, YXYXCoordinateFormat
-from super_gradients.training.utils.output_adapters import DetectionOutputAdapter, ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
+from super_gradients.training.datasets.data_formats.bbox_formats import NormalizedXYWHCoordinateFormat, CXCYWHCoordinateFormat, YXYXCoordinateFormat
+from super_gradients.training.datasets.data_formats.output_adapters.detection_adapter import DetectionOutputAdapter
+from super_gradients.training.datasets.data_formats import (
+    ConcatenatedTensorFormat,
+    BoundingBoxesTensorSliceItem,
+    TensorSliceItem,
+)

 NORMALIZED_XYWH_SCORES_LABELS = ConcatenatedTensorFormat(
     layout=(
@@ -119,7 +124,7 @@ class TestDetectionOutputAdapter(unittest.TestCase):

             with tempfile.TemporaryDirectory() as tmpdirname:
                 adapter_fname = os.path.join(tmpdirname, "adapter.onnx")
-                torch.onnx.export(adapter, inp, f=adapter_fname, input_names=["predictions"], output_names=["output_predictions"])
+                torch.onnx.export(adapter, inp, f=adapter_fname, input_names=["predictions"], output_names=["output_predictions"], opset_version=11)

                 onnx_model = onnx.load(adapter_fname)
                 onnx.checker.check_model(onnx_model)
tests/unit_tests/detection_targets_format_transform_test.py

@@ -2,7 +2,14 @@ import numpy as np
 import unittest

 from super_gradients.training.transforms.transforms import DetectionTargetsFormatTransform
-from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
+
+from super_gradients.training.datasets.data_formats.default_formats import (
+    XYXY_LABEL,
+    LABEL_XYXY,
+    LABEL_CXCYWH,
+    LABEL_NORMALIZED_XYXY,
+    LABEL_NORMALIZED_CXCYWH,
+)


 class DetectionTargetsTransformTest(unittest.TestCase):
@@ -12,110 +19,115 @@ class DetectionTargetsTransformTest(unittest.TestCase):
     def test_label_first_2_label_last(self):
         input = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
         output = np.array([[50, 10, 20, 30, 40]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.XYXY_LABEL,
-                                                    output_format=DetectionTargetsFormat.LABEL_XYXY)
         sample = {"image": self.image, "target": input}
-        self.assertTrue(np.array_equal(transform(sample)["target"], output))
+
+        transform = DetectionTargetsFormatTransform(image_shape=self.image.shape[1:], max_targets=1, input_format=XYXY_LABEL, output_format=LABEL_XYXY)
+        t_output = transform(sample)["target"]
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))

     def test_xyxy_2_normalized_xyxy(self):
         input = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
         _, h, w = self.image.shape
         output = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.LABEL_XYXY,
-                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY)
         sample = {"image": self.image, "target": input}
+
+        transform = DetectionTargetsFormatTransform(
+            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_XYXY, output_format=LABEL_NORMALIZED_XYXY
+        )
         t_output = transform(sample)["target"]
-        self.assertTrue(np.array_equal(output, t_output))
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))

     def test_xyxy_2_cxcywh(self):
         input = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
         _, h, w = self.image.shape
         output = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.LABEL_XYXY,
-                                                    output_format=DetectionTargetsFormat.LABEL_CXCYWH)
         sample = {"image": self.image, "target": input}
+
+        transform = DetectionTargetsFormatTransform(image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_XYXY, output_format=LABEL_CXCYWH)
         t_output = transform(sample)["target"]
-        self.assertTrue(np.array_equal(output, t_output))
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))

     def test_xyxy_2_normalized_cxcywh(self):
         input = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
         _, h, w = self.image.shape
         output = np.array([[10, 30 / w, 40 / h, 20 / w, 20 / h]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.LABEL_XYXY,
-                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_CXCYWH)
         sample = {"image": self.image, "target": input}
+
+        transform = DetectionTargetsFormatTransform(
+            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_XYXY, output_format=LABEL_NORMALIZED_CXCYWH
+        )
         t_output = transform(sample)["target"]
-        self.assertTrue(np.array_equal(output, t_output))
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))

     def test_normalized_xyxy_2_cxcywh(self):
         _, h, w = self.image.shape
         input = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
         output = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY,
-                                                    output_format=DetectionTargetsFormat.LABEL_CXCYWH)
         sample = {"image": self.image, "target": input}
+
+        transform = DetectionTargetsFormatTransform(
+            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_XYXY, output_format=LABEL_CXCYWH
+        )
         t_output = transform(sample)["target"]
-        self.assertTrue(np.allclose(output, t_output))
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))

     def test_normalized_xyxy_2_normalized_cxcywh(self):
         _, h, w = self.image.shape
         input = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
         output = np.array([[10, 30 / w, 40 / h, 20 / w, 20 / h]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY,
-                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_CXCYWH)
         sample = {"image": self.image, "target": input}
+
+        transform = DetectionTargetsFormatTransform(
+            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_XYXY, output_format=LABEL_NORMALIZED_CXCYWH
+        )
         t_output = transform(sample)["target"]
-        self.assertTrue(np.allclose(output, t_output))
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))

     def test_cxcywh_2_xyxy(self):
         output = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
         input = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.LABEL_CXCYWH,
-                                                    output_format=DetectionTargetsFormat.LABEL_XYXY)
         sample = {"image": self.image, "target": input}
+
+        transform = DetectionTargetsFormatTransform(image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_CXCYWH, output_format=LABEL_XYXY)
         t_output = transform(sample)["target"]
-        self.assertTrue(np.array_equal(output, t_output))
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))

     def test_cxcywh_2_normalized_xyxy(self):
         _, h, w = self.image.shape
         output = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
         input = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.LABEL_CXCYWH,
-                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY)
         sample = {"image": self.image, "target": input}
+
+        transform = DetectionTargetsFormatTransform(
+            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_CXCYWH, output_format=LABEL_NORMALIZED_XYXY
+        )
         t_output = transform(sample)["target"]
-        self.assertTrue(np.array_equal(output, t_output))
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))

     def test_normalized_cxcywh_2_xyxy(self):
         _, h, w = self.image.shape
         input = np.array([[10, 30 / w, 40 / h, 20 / w, 20 / h]], dtype=np.float32)
         output = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.LABEL_NORMALIZED_CXCYWH,
-                                                    output_format=DetectionTargetsFormat.LABEL_XYXY)
         sample = {"image": self.image, "target": input}
+
+        transform = DetectionTargetsFormatTransform(
+            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_CXCYWH, output_format=LABEL_XYXY
+        )
         t_output = transform(sample)["target"]
-        self.assertTrue(np.allclose(output, t_output))
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))

     def test_normalized_cxcywh_2_normalized_xyxy(self):
         _, h, w = self.image.shape
         output = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
         input = np.array([[10, 30 / w, 40 / h, 20 / w, 20 / h]], dtype=np.float32)
-        transform = DetectionTargetsFormatTransform(max_targets=1,
-                                                    input_format=DetectionTargetsFormat.LABEL_NORMALIZED_CXCYWH,
-                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY)
         sample = {"image": self.image, "target": input}
+
+        transform = DetectionTargetsFormatTransform(
+            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_CXCYWH, output_format=LABEL_NORMALIZED_XYXY
+        )
         t_output = transform(sample)["target"]
-        self.assertTrue(np.allclose(output, t_output))
+        self.assertTrue(np.allclose(output, t_output, atol=1e-6))


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()