Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#393 Augment HSV will allow configuring bgr_channels in case there are oth…

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-000_augment_hsv_custom_channels
1 changed files with 6 additions and 5 deletions
  1. 6
    5
      src/super_gradients/training/transforms/transforms.py
@@ -711,16 +711,17 @@ class DetectionHSV(DetectionTransform):
     Detection HSV transform.
     Detection HSV transform.
     """
     """
 
 
-    def __init__(self, prob: float, hgain: float = 0.5, sgain: float = 0.5, vgain: float = 0.5):
+    def __init__(self, prob: float, hgain: float = 0.5, sgain: float = 0.5, vgain: float = 0.5, bgr_channels=(0, 1, 2)):
         super(DetectionHSV, self).__init__()
         super(DetectionHSV, self).__init__()
         self.prob = prob
         self.prob = prob
         self.hgain = hgain
         self.hgain = hgain
         self.sgain = sgain
         self.sgain = sgain
         self.vgain = vgain
         self.vgain = vgain
+        self.bgr_channels = bgr_channels
 
 
     def __call__(self, sample: dict) -> dict:
     def __call__(self, sample: dict) -> dict:
         if random.random() < self.prob:
         if random.random() < self.prob:
-            augment_hsv(sample["image"], self.hgain, self.sgain, self.vgain)
+            augment_hsv(sample["image"], self.hgain, self.sgain, self.vgain, self.bgr_channels)
         return sample
         return sample
 
 
 
 
@@ -1037,17 +1038,17 @@ def _mirror(image, boxes, prob=0.5):
     return image, flipped_boxes
     return image, flipped_boxes
 
 
 
 
-def augment_hsv(img: np.array, hgain: float, sgain: float, vgain: float):
+def augment_hsv(img: np.array, hgain: float, sgain: float, vgain: float, bgr_channels=(0, 1, 2)):
     hsv_augs = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain]  # random gains
     hsv_augs = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain]  # random gains
     hsv_augs *= np.random.randint(0, 2, 3)  # random selection of h, s, v
     hsv_augs *= np.random.randint(0, 2, 3)  # random selection of h, s, v
     hsv_augs = hsv_augs.astype(np.int16)
     hsv_augs = hsv_augs.astype(np.int16)
-    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
+    img_hsv = cv2.cvtColor(img[..., bgr_channels], cv2.COLOR_BGR2HSV).astype(np.int16)
 
 
     img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180
     img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180
     img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255)
     img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255)
     img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255)
     img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255)
 
 
-    cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img)  # no return needed
+    img[..., bgr_channels] = cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR)  # no return needed
 
 
 
 
 def rescale_and_pad_to_size(img, input_size, swap=(2, 0, 1), pad_val=114):
 def rescale_and_pad_to_size(img, input_size, swap=(2, 0, 1), pad_val=114):
Discard