Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#559 Feature/sg 468 detection transform support for any number of channels

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-468_detection_transform_support_for_any_number_of_channels
@@ -374,14 +374,16 @@ class DetectionMosaic(DetectionTransform):
         input_dim: (tuple) input dimension.
         input_dim: (tuple) input dimension.
         prob: (float) probability of applying mosaic.
         prob: (float) probability of applying mosaic.
         enable_mosaic: (bool) whether to apply mosaic at all (regardless of prob) (default=True).
         enable_mosaic: (bool) whether to apply mosaic at all (regardless of prob) (default=True).
+        border_value: value for filling borders after applying transforms (default=114).
 
 
     """
     """
 
 
-    def __init__(self, input_dim: tuple, prob: float = 1.0, enable_mosaic: bool = True):
+    def __init__(self, input_dim: tuple, prob: float = 1.0, enable_mosaic: bool = True, border_value=114):
         super(DetectionMosaic, self).__init__(additional_samples_count=3)
         super(DetectionMosaic, self).__init__(additional_samples_count=3)
         self.prob = prob
         self.prob = prob
         self.input_dim = input_dim
         self.input_dim = input_dim
         self.enable_mosaic = enable_mosaic
         self.enable_mosaic = enable_mosaic
+        self.border_value = border_value
 
 
     def close(self):
     def close(self):
         self.additional_samples_count = 0
         self.additional_samples_count = 0
@@ -410,7 +412,7 @@ class DetectionMosaic(DetectionTransform):
                 # generate output mosaic image
                 # generate output mosaic image
                 (h, w, c) = img.shape[:3]
                 (h, w, c) = img.shape[:3]
                 if i_mosaic == 0:
                 if i_mosaic == 0:
-                    mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8)
+                    mosaic_img = np.full((input_h * 2, input_w * 2, c), self.border_value, dtype=np.uint8)
 
 
                 # suffix l means large image, while s means small image in mosaic aug.
                 # suffix l means large image, while s means small image in mosaic aug.
                 (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(i_mosaic, xc, yc, w, h, input_h, input_w)
                 (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(i_mosaic, xc, yc, w, h, input_h, input_w)
@@ -488,10 +490,23 @@ class DetectionRandomAffine(DetectionTransform):
      area_thr:(float) threshold for area ratio between original image and the transformed one, when filter_box_candidates = True.
      area_thr:(float) threshold for area ratio between original image and the transformed one, when filter_box_candidates = True.
       Bounding boxes with such ratio smaller than this value will be filtered out. (default=0.1)
       Bounding boxes with such ratio smaller than this value will be filtered out. (default=0.1)
 
 
+     border_value: value for filling borders after applying transforms (default=114).
+
+
     """
     """
 
 
     def __init__(
     def __init__(
-        self, degrees=10, translate=0.1, scales=0.1, shear=10, target_size=(640, 640), filter_box_candidates: bool = False, wh_thr=2, ar_thr=20, area_thr=0.1
+        self,
+        degrees=10,
+        translate=0.1,
+        scales=0.1,
+        shear=10,
+        target_size=(640, 640),
+        filter_box_candidates: bool = False,
+        wh_thr=2,
+        ar_thr=20,
+        area_thr=0.1,
+        border_value=114,
     ):
     ):
         super(DetectionRandomAffine, self).__init__()
         super(DetectionRandomAffine, self).__init__()
         self.degrees = degrees
         self.degrees = degrees
@@ -504,6 +519,7 @@ class DetectionRandomAffine(DetectionTransform):
         self.wh_thr = wh_thr
         self.wh_thr = wh_thr
         self.ar_thr = ar_thr
         self.ar_thr = ar_thr
         self.area_thr = area_thr
         self.area_thr = area_thr
+        self.border_value = border_value
 
 
     def close(self):
     def close(self):
         self.enable = False
         self.enable = False
@@ -523,6 +539,7 @@ class DetectionRandomAffine(DetectionTransform):
                 wh_thr=self.wh_thr,
                 wh_thr=self.wh_thr,
                 area_thr=self.area_thr,
                 area_thr=self.area_thr,
                 ar_thr=self.ar_thr,
                 ar_thr=self.ar_thr,
+                border_value=self.border_value,
             )
             )
             sample["image"] = img
             sample["image"] = img
             sample["target"] = target
             sample["target"] = target
@@ -539,15 +556,18 @@ class DetectionMixup(DetectionTransform):
         prob: (float) probability of applying mixup.
         prob: (float) probability of applying mixup.
         enable_mixup: (bool) whether to apply mixup at all (regardless of prob) (default=True).
         enable_mixup: (bool) whether to apply mixup at all (regardless of prob) (default=True).
         flip_prob: (float) probability to apply horizontal flip to the additional sample.
         flip_prob: (float) probability to apply horizontal flip to the additional sample.
+        border_value: value for filling borders after applying transform (default=114).
+
     """
     """
 
 
-    def __init__(self, input_dim, mixup_scale, prob=1.0, enable_mixup=True, flip_prob=0.5):
+    def __init__(self, input_dim, mixup_scale, prob=1.0, enable_mixup=True, flip_prob=0.5, border_value=114):
         super(DetectionMixup, self).__init__(additional_samples_count=1, non_empty_targets=True)
         super(DetectionMixup, self).__init__(additional_samples_count=1, non_empty_targets=True)
         self.input_dim = input_dim
         self.input_dim = input_dim
         self.mixup_scale = mixup_scale
         self.mixup_scale = mixup_scale
         self.prob = prob
         self.prob = prob
         self.enable_mixup = enable_mixup
         self.enable_mixup = enable_mixup
         self.flip_prob = flip_prob
         self.flip_prob = flip_prob
+        self.border_value = border_value
 
 
     def close(self):
     def close(self):
         self.additional_samples_count = 0
         self.additional_samples_count = 0
@@ -567,9 +587,9 @@ class DetectionMixup(DetectionTransform):
             jit_factor = random.uniform(*self.mixup_scale)
             jit_factor = random.uniform(*self.mixup_scale)
 
 
             if len(img.shape) == 3:
             if len(img.shape) == 3:
-                cp_img = np.ones((self.input_dim[0], self.input_dim[1], 3), dtype=np.uint8) * 114
+                cp_img = np.ones((self.input_dim[0], self.input_dim[1], img.shape[2]), dtype=np.uint8) * self.border_value
             else:
             else:
-                cp_img = np.ones(self.input_dim, dtype=np.uint8) * 114
+                cp_img = np.ones(self.input_dim, dtype=np.uint8) * self.border_value
 
 
             cp_scale_ratio = min(self.input_dim[0] / img.shape[0], self.input_dim[1] / img.shape[1])
             cp_scale_ratio = min(self.input_dim[0] / img.shape[0], self.input_dim[1] / img.shape[1])
             resized_img = cv2.resize(
             resized_img = cv2.resize(
@@ -588,7 +608,12 @@ class DetectionMixup(DetectionTransform):
 
 
             origin_h, origin_w = cp_img.shape[:2]
             origin_h, origin_w = cp_img.shape[:2]
             target_h, target_w = origin_img.shape[:2]
             target_h, target_w = origin_img.shape[:2]
-            padded_img = np.zeros((max(origin_h, target_h), max(origin_w, target_w), 3), dtype=np.uint8)
+
+            if len(img.shape) == 3:
+                padded_img = np.zeros((max(origin_h, target_h), max(origin_w, target_w), img.shape[2]), dtype=np.uint8)
+            else:
+                padded_img = np.zeros((max(origin_h, target_h), max(origin_w, target_w)), dtype=np.uint8)
+
             padded_img[:origin_h, :origin_w] = cp_img
             padded_img[:origin_h, :origin_w] = cp_img
 
 
             x_offset, y_offset = 0, 0
             x_offset, y_offset = 0, 0
@@ -690,6 +715,15 @@ class DetectionHorizontalFlip(DetectionTransform):
 class DetectionHSV(DetectionTransform):
 class DetectionHSV(DetectionTransform):
     """
     """
     Detection HSV transform.
     Detection HSV transform.
+
+    Attributes:
+        prob: (float) probability to apply the transform.
+        hgain: (float) hue gain (default=0.5)
+        sgain: (float) saturation gain (default=0.5)
+        vgain: (float) value gain (default=0.5)
+        bgr_channels: (tuple) channel indices of the BGR channels- useful for images with >3 channels,
+         or when BGR channels are in different order. (default=(0,1,2)).
+
     """
     """
 
 
     def __init__(self, prob: float, hgain: float = 0.5, sgain: float = 0.5, vgain: float = 0.5, bgr_channels=(0, 1, 2)):
     def __init__(self, prob: float, hgain: float = 0.5, sgain: float = 0.5, vgain: float = 0.5, bgr_channels=(0, 1, 2)):
@@ -699,8 +733,20 @@ class DetectionHSV(DetectionTransform):
         self.sgain = sgain
         self.sgain = sgain
         self.vgain = vgain
         self.vgain = vgain
         self.bgr_channels = bgr_channels
         self.bgr_channels = bgr_channels
+        self._additional_channels_warned = False
 
 
     def __call__(self, sample: dict) -> dict:
     def __call__(self, sample: dict) -> dict:
+        if sample["image"].shape[2] < 3:
+            raise ValueError("HSV transform expects at least 3 channels, got: " + str(sample["image"].shape[2]))
+        if sample["image"].shape[2] > 3 and not self._additional_channels_warned:
+            logger.warning(
+                "HSV transform received image with "
+                + str(sample["image"].shape[2])
+                + " channels. HSV transform will only be applied on channels: "
+                + str(self.bgr_channels)
+                + "."
+            )
+            self._additional_channels_warned = True
         if random.random() < self.prob:
         if random.random() < self.prob:
             augment_hsv(sample["image"], self.hgain, self.sgain, self.vgain, self.bgr_channels)
             augment_hsv(sample["image"], self.hgain, self.sgain, self.vgain, self.bgr_channels)
         return sample
         return sample
@@ -937,6 +983,7 @@ def random_affine(
     wh_thr=2,
     wh_thr=2,
     ar_thr=20,
     ar_thr=20,
     area_thr=0.1,
     area_thr=0.1,
+    border_value=114,
 ):
 ):
     """
     """
     Performs random affine transform to img, targets
     Performs random affine transform to img, targets
@@ -962,13 +1009,16 @@ def random_affine(
 
 
     :param area_thr:(float) threshold for area ratio between original image and the transformed one, when filter_box_candidates = True.
     :param area_thr:(float) threshold for area ratio between original image and the transformed one, when filter_box_candidates = True.
       Bounding boxes with such ratio smaller than this value will be filtered out. (default=0.1)
       Bounding boxes with such ratio smaller than this value will be filtered out. (default=0.1)
+
+    :param border_value: value for filling borders after applying transforms (default=114).
+
     :return:            Image and Target with applied random affine
     :return:            Image and Target with applied random affine
     """
     """
 
 
     targets_seg = np.zeros((targets.shape[0], 0)) if targets_seg is None else targets_seg
     targets_seg = np.zeros((targets.shape[0], 0)) if targets_seg is None else targets_seg
     M, scale = get_affine_matrix(target_size, degrees, translate, scales, shear)
     M, scale = get_affine_matrix(target_size, degrees, translate, scales, shear)
 
 
-    img = cv2.warpAffine(img, M, dsize=target_size, borderValue=(114, 114, 114))
+    img = cv2.warpAffine(img, M, dsize=target_size, borderValue=border_value)
 
 
     # Transform label coordinates
     # Transform label coordinates
     if len(targets) > 0:
     if len(targets) > 0:
Discard
@@ -1,11 +1,29 @@
 import tempfile
 import tempfile
+import os
 import unittest
 import unittest
+from typing import Dict, Union, Any
+
+import numpy as np
+import pkg_resources
+from hydra import initialize_config_dir, compose
+from hydra.core.global_hydra import GlobalHydra
+from pydantic.main import deepcopy
 
 
 import super_gradients
 import super_gradients
+from super_gradients.training.dataloaders.dataloaders import _process_dataset_params
 from super_gradients.training.datasets import PascalVOCDetectionDataset, COCODetectionDataset
 from super_gradients.training.datasets import PascalVOCDetectionDataset, COCODetectionDataset
 from super_gradients.training.transforms import DetectionMosaic, DetectionPaddedRescale, DetectionTargetsFormatTransform
 from super_gradients.training.transforms import DetectionMosaic, DetectionPaddedRescale, DetectionTargetsFormatTransform
 from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
 from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
 from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException
 from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException
+from super_gradients.training.utils.hydra_utils import normalize_path
+
+
+class COCODetectionDataset6Channels(COCODetectionDataset):
+    def get_sample(self, index: int) -> Dict[str, Union[np.ndarray, Any]]:
+        img = self.get_resized_image(index)
+        img = np.concatenate((img, img), 2)
+        annotation = deepcopy(self.annotations[index])
+        return {"image": img, **annotation}
 
 
 
 
 class DatasetIntegrationTest(unittest.TestCase):
 class DatasetIntegrationTest(unittest.TestCase):
@@ -83,6 +101,21 @@ class DatasetIntegrationTest(unittest.TestCase):
             sampled_dataset = PascalVOCDetectionDataset(max_num_samples=max_num_samples, **self.pascal_base_config)
             sampled_dataset = PascalVOCDetectionDataset(max_num_samples=max_num_samples, **self.pascal_base_config)
             self.assertEqual(len(sampled_dataset), min(max_num_samples, len(full_dataset)))
             self.assertEqual(len(sampled_dataset), min(max_num_samples, len(full_dataset)))
 
 
+    def test_detection_dataset_transforms_with_unique_channel_count(self):
+        GlobalHydra.instance().clear()
+        sg_recipes_dir = pkg_resources.resource_filename("super_gradients.recipes", "")
+        dataset_config = os.path.join("dataset_params", "coco_detection_dataset_params")
+        with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir), version_base="1.2"):
+            # config is relative to a module
+            cfg = compose(config_name=normalize_path(dataset_config))
+            dataset_params = _process_dataset_params(cfg, dict(), True)
+
+        coco_base_recipe_transforms = dataset_params["transforms"]
+        dataset_config = deepcopy(self.dataset_coco_base_config)
+        dataset_config["transforms"] = coco_base_recipe_transforms
+        dataset = COCODetectionDataset6Channels(**dataset_config)
+        self.assertEqual(dataset.__getitem__(0)[0].shape[0], 6)
+
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     unittest.main()
     unittest.main()
Discard