Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#622 Hotfix/sg 000 transform backward compatibility

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000-transform_backward_compatibility
@@ -35,7 +35,7 @@ train_dataset_params:
         input_dim: ${dataset_params.train_dataset_params.input_dim}
         input_dim: ${dataset_params.train_dataset_params.input_dim}
         max_targets: 120
         max_targets: 120
     - DetectionTargetsFormatTransform:
     - DetectionTargetsFormatTransform:
-        image_shape: ${dataset_params.train_dataset_params.input_dim}
+        input_dim: ${dataset_params.train_dataset_params.input_dim}
         output_format: LABEL_CXCYWH
         output_format: LABEL_CXCYWH
   tight_box_rotation: False
   tight_box_rotation: False
   class_inclusion_list:
   class_inclusion_list:
@@ -69,8 +69,8 @@ val_dataset_params:
   - DetectionPaddedRescale:
   - DetectionPaddedRescale:
       input_dim: ${dataset_params.val_dataset_params.input_dim}
       input_dim: ${dataset_params.val_dataset_params.input_dim}
   - DetectionTargetsFormatTransform:
   - DetectionTargetsFormatTransform:
-      image_shape: ${dataset_params.val_dataset_params.input_dim}
       max_targets: 50
       max_targets: 50
+      input_dim: ${dataset_params.val_dataset_params.input_dim}
       output_format: LABEL_CXCYWH
       output_format: LABEL_CXCYWH
   tight_box_rotation: False
   tight_box_rotation: False
   class_inclusion_list:
   class_inclusion_list:
Discard
@@ -30,7 +30,7 @@ train_dataset_params:
         input_dim: ${dataset_params.train_dataset_params.input_dim}
         input_dim: ${dataset_params.train_dataset_params.input_dim}
         max_targets: 120
         max_targets: 120
     - DetectionTargetsFormatTransform:
     - DetectionTargetsFormatTransform:
-        image_shape: ${dataset_params.train_dataset_params.input_dim}
+        input_dim: ${dataset_params.train_dataset_params.input_dim}
         max_targets: 50
         max_targets: 50
         output_format: LABEL_NORMALIZED_CXCYWH
         output_format: LABEL_NORMALIZED_CXCYWH
 
 
@@ -65,7 +65,7 @@ val_dataset_params:
     - DetectionPaddedRescale:
     - DetectionPaddedRescale:
         input_dim: ${dataset_params.val_dataset_params.input_dim}
         input_dim: ${dataset_params.val_dataset_params.input_dim}
     - DetectionTargetsFormatTransform:
     - DetectionTargetsFormatTransform:
-        image_shape: ${dataset_params.val_dataset_params.input_dim}
+        input_dim: ${dataset_params.val_dataset_params.input_dim}
         max_targets: 50
         max_targets: 50
         output_format: LABEL_NORMALIZED_CXCYWH
         output_format: LABEL_NORMALIZED_CXCYWH
   tight_box_rotation: False
   tight_box_rotation: False
Discard
@@ -4,13 +4,12 @@ train_dataset_params:
   cache: False
   cache: False
   cache_dir:
   cache_dir:
   transforms:
   transforms:
-  - DetectionPaddedRescale:
-      input_dim: ${dataset_params.train_dataset_params.input_dim}
-  - DetectionTargetsFormatTransform:
-      max_targets: 50
-      output_format:
-        _target_: super_gradients.training.utils.detection_utils.DetectionTargetsFormat # targets format
-        value: LABEL_CXCYWH
+    - DetectionPaddedRescale:
+        input_dim: ${dataset_params.train_dataset_params.input_dim}
+    - DetectionTargetsFormatTransform:
+        max_targets: 50
+        input_dim: ${dataset_params.train_dataset_params.input_dim}
+        output_format: LABEL_CXCYWH
   class_inclusion_list:
   class_inclusion_list:
   max_num_samples:
   max_num_samples:
   download: True
   download: True
@@ -21,13 +20,12 @@ val_dataset_params:
   cache: False
   cache: False
   cache_dir:
   cache_dir:
   transforms:
   transforms:
-  - DetectionPaddedRescale:
-      input_dim: ${dataset_params.train_dataset_params.input_dim}
-  - DetectionTargetsFormatTransform:
-      max_targets: 50
-      output_format:
-        _target_: super_gradients.training.utils.detection_utils.DetectionTargetsFormat # targets format
-        value: LABEL_CXCYWH
+    - DetectionPaddedRescale:
+        input_dim: ${dataset_params.train_dataset_params.input_dim}
+    - DetectionTargetsFormatTransform:
+        max_targets: 50
+        input_dim: ${dataset_params.train_dataset_params.input_dim}
+        output_format: LABEL_CXCYWH
   images_sub_directory: images/test2007/
   images_sub_directory: images/test2007/
   class_inclusion_list:
   class_inclusion_list:
   max_num_samples:
   max_num_samples:
Discard
@@ -41,6 +41,20 @@ class ConcatenatedTensorFormat(DetectionOutputFormat):
     A layout defines the order of concatenated tensors. For instance:
     A layout defines the order of concatenated tensors. For instance:
     - layout: (bboxes, scores, labels) gives a Tensor that is product of torch.cat([bboxes, scores, labels], dim=1)
     - layout: (bboxes, scores, labels) gives a Tensor that is product of torch.cat([bboxes, scores, labels], dim=1)
     - layout: (labels, bboxes) produce a Tensor from torch.cat([labels, bboxes], dim=1)
     - layout: (labels, bboxes) produce a Tensor from torch.cat([labels, bboxes], dim=1)
+
+
+    >>> from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
+    >>> from super_gradients.training.datasets.data_formats.bbox_formats import XYXYCoordinateFormat, NormalizedXYWHCoordinateFormat
+    >>>
+    >>> custom_format = ConcatenatedTensorFormat(
+    >>>     layout=(
+    >>>         BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
+    >>>         TensorSliceItem(name="label", length=1),
+    >>>         TensorSliceItem(name="distance", length=1),
+    >>>         TensorSliceItem(name="attributes", length=4),
+    >>>     )
+    >>> )
+
     """
     """
 
 
     layout: Mapping[str, TensorSliceItem]
     layout: Mapping[str, TensorSliceItem]
Discard
@@ -11,7 +11,7 @@ import cv2
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.data_formats_factory import ConcatenatedTensorFormatFactory
 from super_gradients.common.factories.data_formats_factory import ConcatenatedTensorFormatFactory
-from super_gradients.training.utils.detection_utils import get_mosaic_coordinate, adjust_box_anns, xyxy2cxcywh, cxcywh2xyxy
+from super_gradients.training.utils.detection_utils import get_mosaic_coordinate, adjust_box_anns, xyxy2cxcywh, cxcywh2xyxy, DetectionTargetsFormat
 from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormatConverter
 from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormatConverter
 from super_gradients.training.datasets.data_formats.formats import filter_on_bboxes, ConcatenatedTensorFormat
 from super_gradients.training.datasets.data_formats.formats import filter_on_bboxes, ConcatenatedTensorFormat
 from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_CXCYWH
 from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_CXCYWH
@@ -764,8 +764,8 @@ class DetectionTargetsFormatTransform(DetectionTransform):
 
 
     Convert targets in input_format to output_format, filter small bboxes and pad targets.
     Convert targets in input_format to output_format, filter small bboxes and pad targets.
     Attributes:
     Attributes:
-        image_shape:        Shape of the images to transform.
-        input_format:       Format of the input targets. For instance [xmin, ymin, xmax, ymax, cls_id] refers to XYXY_LABEL
+        input_dim:          Shape of the images to transform.
+        input_format:       Format of the input targets. For instance [xmin, ymin, xmax, ymax, cls_id] refers to XYXY_LABEL.
         output_format:      Format of the output targets. For instance [xmin, ymin, xmax, ymax, cls_id] refers to XYXY_LABEL
         output_format:      Format of the output targets. For instance [xmin, ymin, xmax, ymax, cls_id] refers to XYXY_LABEL
         min_bbox_edge_size: bboxes with edge size lower than this value will be removed.
         min_bbox_edge_size: bboxes with edge size lower than this value will be removed.
         max_targets:        Max objects in single image, padding target to this size.
         max_targets:        Max objects in single image, padding target to this size.
@@ -775,20 +775,43 @@ class DetectionTargetsFormatTransform(DetectionTransform):
     @resolve_param("output_format", ConcatenatedTensorFormatFactory())
     @resolve_param("output_format", ConcatenatedTensorFormatFactory())
     def __init__(
     def __init__(
         self,
         self,
-        image_shape: tuple,
+        input_dim: Optional[tuple] = None,
         input_format: ConcatenatedTensorFormat = XYXY_LABEL,
         input_format: ConcatenatedTensorFormat = XYXY_LABEL,
         output_format: ConcatenatedTensorFormat = LABEL_CXCYWH,
         output_format: ConcatenatedTensorFormat = LABEL_CXCYWH,
         min_bbox_edge_size: float = 1,
         min_bbox_edge_size: float = 1,
         max_targets: int = 120,
         max_targets: int = 120,
     ):
     ):
         super(DetectionTargetsFormatTransform, self).__init__()
         super(DetectionTargetsFormatTransform, self).__init__()
+        if isinstance(input_format, DetectionTargetsFormat) or isinstance(output_format, DetectionTargetsFormat):
+            raise TypeError(
+                "DetectionTargetsFormat is not supported for input_format and output_format starting from super_gradients==3.0.7.\n"
+                "You can either:\n"
+                "\t - use builtin format among super_gradients.training.datasets.data_formats.default_formats.<FORMAT_NAME> (e.g. XYXY_LABEL, CXCY_LABEL, ..)\n"
+                "\t - define your custom format using super_gradients.training.datasets.data_formats.formats.ConcatenatedTensorFormat\n"
+            )
         self.input_format = input_format
         self.input_format = input_format
         self.output_format = output_format
         self.output_format = output_format
         self.max_targets = max_targets
         self.max_targets = max_targets
-        self.min_bbox_edge_size = min_bbox_edge_size / max(image_shape) if output_format.bboxes_format.format.normalized else min_bbox_edge_size
-        self.targets_format_converter = ConcatenatedTensorFormatConverter(input_format=input_format, output_format=output_format, image_shape=image_shape)
+        self.min_bbox_edge_size = min_bbox_edge_size
+        self.input_dim = None
+
+        if input_dim is not None:
+            self._setup_input_dim_related_params(input_dim)
+
+    def _setup_input_dim_related_params(self, input_dim: tuple):
+        """Setup all the parameters that are related to input_dim."""
+        self.input_dim = input_dim
+        self.min_bbox_edge_size = self.min_bbox_edge_size / max(input_dim) if self.output_format.bboxes_format.format.normalized else self.min_bbox_edge_size
+        self.targets_format_converter = ConcatenatedTensorFormatConverter(
+            input_format=self.input_format, output_format=self.output_format, image_shape=input_dim
+        )
 
 
     def __call__(self, sample: dict) -> dict:
     def __call__(self, sample: dict) -> dict:
+
+        # if self.input_dim is not set yet, it will be set with the first batch
+        if self.input_dim is None:
+            self._setup_input_dim_related_params(input_dim=sample["image"].shape[1:])
+
         sample["target"] = self.apply_on_targets(sample["target"])
         sample["target"] = self.apply_on_targets(sample["target"])
         if "crowd_target" in sample.keys():
         if "crowd_target" in sample.keys():
             sample["crowd_target"] = self.apply_on_targets(sample["crowd_target"])
             sample["crowd_target"] = self.apply_on_targets(sample["crowd_target"])
Discard
@@ -13,7 +13,7 @@ import super_gradients
 from super_gradients.training.dataloaders.dataloaders import _process_dataset_params
 from super_gradients.training.dataloaders.dataloaders import _process_dataset_params
 from super_gradients.training.datasets import PascalVOCDetectionDataset, COCODetectionDataset
 from super_gradients.training.datasets import PascalVOCDetectionDataset, COCODetectionDataset
 from super_gradients.training.transforms import DetectionMosaic, DetectionPaddedRescale, DetectionTargetsFormatTransform
 from super_gradients.training.transforms import DetectionMosaic, DetectionPaddedRescale, DetectionTargetsFormatTransform
-from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
+from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
 from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException
 from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException
 from super_gradients.training.utils.hydra_utils import normalize_path
 from super_gradients.training.utils.hydra_utils import normalize_path
 
 
@@ -35,7 +35,7 @@ class DatasetIntegrationTest(unittest.TestCase):
         transforms = [
         transforms = [
             DetectionMosaic(input_dim=(640, 640), prob=0.8),
             DetectionMosaic(input_dim=(640, 640), prob=0.8),
             DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
             DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
-            DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.XYXY_LABEL),
+            DetectionTargetsFormatTransform(input_dim=(640, 640), output_format=XYXY_LABEL),
         ]
         ]
 
 
         self.test_dir = tempfile.TemporaryDirectory().name
         self.test_dir = tempfile.TemporaryDirectory().name
Discard
@@ -21,7 +21,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         output = np.array([[50, 10, 20, 30, 40]], dtype=np.float32)
         output = np.array([[50, 10, 20, 30, 40]], dtype=np.float32)
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
-        transform = DetectionTargetsFormatTransform(image_shape=self.image.shape[1:], max_targets=1, input_format=XYXY_LABEL, output_format=LABEL_XYXY)
+        transform = DetectionTargetsFormatTransform(input_dim=self.image.shape[1:], max_targets=1, input_format=XYXY_LABEL, output_format=LABEL_XYXY)
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
 
 
@@ -31,9 +31,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         output = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
         output = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
-        transform = DetectionTargetsFormatTransform(
-            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_XYXY, output_format=LABEL_NORMALIZED_XYXY
-        )
+        transform = DetectionTargetsFormatTransform(input_dim=self.image.shape[1:], max_targets=1, input_format=LABEL_XYXY, output_format=LABEL_NORMALIZED_XYXY)
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
 
 
@@ -43,7 +41,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         output = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
         output = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
-        transform = DetectionTargetsFormatTransform(image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_XYXY, output_format=LABEL_CXCYWH)
+        transform = DetectionTargetsFormatTransform(input_dim=self.image.shape[1:], max_targets=1, input_format=LABEL_XYXY, output_format=LABEL_CXCYWH)
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
 
 
@@ -54,7 +52,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
         transform = DetectionTargetsFormatTransform(
         transform = DetectionTargetsFormatTransform(
-            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_XYXY, output_format=LABEL_NORMALIZED_CXCYWH
+            input_dim=self.image.shape[1:], max_targets=1, input_format=LABEL_XYXY, output_format=LABEL_NORMALIZED_CXCYWH
         )
         )
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
@@ -66,7 +64,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
         transform = DetectionTargetsFormatTransform(
         transform = DetectionTargetsFormatTransform(
-            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_XYXY, output_format=LABEL_CXCYWH
+            input_dim=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_XYXY, output_format=LABEL_CXCYWH
         )
         )
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
@@ -78,7 +76,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
         transform = DetectionTargetsFormatTransform(
         transform = DetectionTargetsFormatTransform(
-            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_XYXY, output_format=LABEL_NORMALIZED_CXCYWH
+            input_dim=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_XYXY, output_format=LABEL_NORMALIZED_CXCYWH
         )
         )
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
@@ -88,7 +86,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         input = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
         input = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
-        transform = DetectionTargetsFormatTransform(image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_CXCYWH, output_format=LABEL_XYXY)
+        transform = DetectionTargetsFormatTransform(input_dim=self.image.shape[1:], max_targets=1, input_format=LABEL_CXCYWH, output_format=LABEL_XYXY)
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
 
 
@@ -99,7 +97,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
         transform = DetectionTargetsFormatTransform(
         transform = DetectionTargetsFormatTransform(
-            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_CXCYWH, output_format=LABEL_NORMALIZED_XYXY
+            input_dim=self.image.shape[1:], max_targets=1, input_format=LABEL_CXCYWH, output_format=LABEL_NORMALIZED_XYXY
         )
         )
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
@@ -111,7 +109,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
         transform = DetectionTargetsFormatTransform(
         transform = DetectionTargetsFormatTransform(
-            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_CXCYWH, output_format=LABEL_XYXY
+            input_dim=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_CXCYWH, output_format=LABEL_XYXY
         )
         )
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
@@ -123,7 +121,7 @@ class DetectionTargetsTransformTest(unittest.TestCase):
         sample = {"image": self.image, "target": input}
         sample = {"image": self.image, "target": input}
 
 
         transform = DetectionTargetsFormatTransform(
         transform = DetectionTargetsFormatTransform(
-            image_shape=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_CXCYWH, output_format=LABEL_NORMALIZED_XYXY
+            input_dim=self.image.shape[1:], max_targets=1, input_format=LABEL_NORMALIZED_CXCYWH, output_format=LABEL_NORMALIZED_XYXY
         )
         )
         t_output = transform(sample)["target"]
         t_output = transform(sample)["target"]
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
         self.assertTrue(np.allclose(output, t_output, atol=1e-6))
Discard