#208 detection transforms and their utils added

Merged
GitHub User merged 1 commit into Deci-AI:master from deci-ai:feature/SG-13_detection_tranforms
@@ -5,10 +5,17 @@ from typing import Optional, Union, Tuple, List, Sequence
 
 from PIL import Image, ImageFilter, ImageOps
 from torchvision import transforms as transforms
+import numpy as np
+import cv2
+from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.training.utils.detection_utils import get_mosaic_coordinate, \
+    adjust_box_anns, xyxy2cxcywh, cxcywh2xyxy, DetectionTargetsFormat
 
 image_resample = Image.BILINEAR
 mask_resample = Image.NEAREST
 
+logger = get_logger(__name__)
+
 
 class SegmentationTransform:
     def __call__(self, *args, **kwargs):
@@ -334,4 +341,651 @@ def _validate_fill_values_arguments(fill_mask: int, fill_image: Union[int, Tuple
     if min(fill_image) < 0 or max(fill_image) > 255 or fill_mask < 0 or fill_mask > 255:
         raise ValueError(f"Fill value must be a value from 0 to 255,"
                          f" found: fill_image = {fill_image}, fill_mask = {fill_mask}")
-    return fill_mask, fill_image
+    return fill_mask, fill_image
+
+
+class DetectionTransform:
+    """
+    Detection transform base class.
+
+    Complex transforms that require extra data loading can use the additional_samples_count attribute in a
+     similar fashion to what's been done in COCODetectionDatasetYolox:
+
+    self._load_additional_inputs_for_transform(sample, transform)
+
+    # after the above call, sample["additional_samples"] holds a list of additional inputs and targets.
+
+    sample = transform(sample)
+
+    Attributes:
+        additional_samples_count: (int) additional samples to be loaded.
+        non_empty_targets: (bool) whether the additional samples must have non-empty targets.
+    """
+
+    def __init__(self, additional_samples_count: int = 0, non_empty_targets: bool = False):
+        self.additional_samples_count = additional_samples_count
+        self.non_empty_targets = non_empty_targets
+
+    def __call__(self, sample: Union[dict, list]):
+        raise NotImplementedError
+
+    def __repr__(self):
+        return self.__class__.__name__ + str(self.__dict__).replace('{', '(').replace('}', ')')
+
+
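To make the contract concrete, here is a minimal sketch of the dataset-side protocol described in the docstring above (load_random_sample stands in for whatever sampling the dataset implements; it is not part of this PR). DetectionMosaic, defined just below, declares additional_samples_count=3:

    # illustrative: the dataset attaches extra samples before invoking the transform
    transform = DetectionMosaic(input_dim=(640, 640))
    sample = {"image": image, "target": target, "target_seg": target_seg, "id": 0}
    sample["additional_samples"] = [load_random_sample() for _ in range(transform.additional_samples_count)]
    sample = transform(sample)  # returns the mosaic-augmented sample dict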
+class DetectionMosaic(DetectionTransform):
+    """
+    DetectionMosaic detection transform
+
+    Attributes:
+        input_dim: (tuple) input dimension.
+        prob: (float) probability of applying mosaic.
+        enable_mosaic: (bool) whether to apply mosaic at all (regardless of prob) (default=True).
+
+    """
+
+    def __init__(self, input_dim: tuple, prob: float = 1., enable_mosaic: bool = True):
+        super(DetectionMosaic, self).__init__(additional_samples_count=3)
+        self.prob = prob
+        self.input_dim = input_dim
+        self.enable_mosaic = enable_mosaic
+
+    def close(self):
+        self.additional_samples_count = 0
+        self.enable_mosaic = False
+
+    def __call__(self, sample: Union[dict, list]):
+        if self.enable_mosaic and random.random() < self.prob:
+            mosaic_labels = []
+            mosaic_labels_seg = []
+            input_h, input_w = self.input_dim[0], self.input_dim[1]
+
+            # mosaic center (xc, yc), drawn uniformly from (0.5 * dim, 1.5 * dim)
+            yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
+            xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))
+
+            # 3 additional samples, total of 4
+            all_samples = [sample] + sample["additional_samples"]
+
+            for i_mosaic, mosaic_sample in enumerate(all_samples):
+                img, _labels, _labels_seg, _ = mosaic_sample["image"], mosaic_sample["target"], mosaic_sample[
+                    "target_seg"], mosaic_sample["id"]
+                h0, w0 = img.shape[:2]  # orig hw
+                scale = min(1. * input_h / h0, 1. * input_w / w0)
+                img = cv2.resize(
+                    img, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR
+                )
+                # generate output mosaic image
+                (h, w, c) = img.shape[:3]
+                if i_mosaic == 0:
+                    mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8)
+
+                # suffix l means large image, while s means small image in mosaic aug.
+                (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(i_mosaic, xc, yc, w, h,
+                                                                                           input_h, input_w)
+
+                mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
+                padw, padh = l_x1 - s_x1, l_y1 - s_y1
+
+                labels = _labels.copy()
+                labels_seg = _labels_seg.copy()
+
+                # shift box and segmentation coordinates from the resized image frame into the mosaic frame
+                if _labels.size > 0:
+                    labels[:, 0] = scale * _labels[:, 0] + padw
+                    labels[:, 1] = scale * _labels[:, 1] + padh
+                    labels[:, 2] = scale * _labels[:, 2] + padw
+                    labels[:, 3] = scale * _labels[:, 3] + padh
+
+                    labels_seg[:, ::2] = scale * labels_seg[:, ::2] + padw
+                    labels_seg[:, 1::2] = scale * labels_seg[:, 1::2] + padh
+                mosaic_labels_seg.append(labels_seg)
+                mosaic_labels.append(labels)
+
+            if len(mosaic_labels):
+                mosaic_labels = np.concatenate(mosaic_labels, 0)
+                np.clip(mosaic_labels[:, 0], 0, 2 * input_w, out=mosaic_labels[:, 0])
+                np.clip(mosaic_labels[:, 1], 0, 2 * input_h, out=mosaic_labels[:, 1])
+                np.clip(mosaic_labels[:, 2], 0, 2 * input_w, out=mosaic_labels[:, 2])
+                np.clip(mosaic_labels[:, 3], 0, 2 * input_h, out=mosaic_labels[:, 3])
+                mosaic_labels_seg = np.concatenate(mosaic_labels_seg, 0)
+                np.clip(mosaic_labels_seg[:, ::2], 0, 2 * input_w, out=mosaic_labels_seg[:, ::2])
+                np.clip(mosaic_labels_seg[:, 1::2], 0, 2 * input_h, out=mosaic_labels_seg[:, 1::2])
+
+            sample = {"image": mosaic_img, "target": mosaic_labels, "target_seg": mosaic_labels_seg,
+                      "info": (mosaic_img.shape[1], mosaic_img.shape[0]), "id": sample["id"]}
+        return sample
+
+
+class DetectionRandomAffine(DetectionTransform):
+    """
+    DetectionRandomAffine detection transform
+
+    Attributes:
+        target_size: (tuple) desired output shape.
+
+        degrees: (Union[tuple, float]) degrees for random rotation, when float the random values are drawn uniformly
+            from (-degrees, degrees)
+
+        translate: (Union[tuple, float]) translate size (in pixels) for random translation, when float the random values
+            are drawn uniformly from (-translate, translate)
+
+        scales: (Union[tuple, float]) values for random rescale, when float the random values are drawn uniformly
+            from (1 - scales, 1 + scales)
+
+        shear: (Union[tuple, float]) degrees for random shear, when float the random values are drawn uniformly
+            from (-shear, shear)
+
+        enable: (bool) whether to apply the transform at all.
+    """
+
+    def __init__(self, degrees=10, translate=0.1, scales=0.1, shear=10, target_size=(640, 640)):
+        super(DetectionRandomAffine, self).__init__()
+        self.degrees = degrees
+        self.translate = translate
+        self.scale = scales
+        self.shear = shear
+        self.target_size = target_size
+        self.enable = True
+
+    def close(self):
+        self.enable = False
+
+    def __call__(self, sample: dict):
+        if self.enable:
+            img, target = random_affine(
+                sample["image"],
+                sample["target"],
+                sample["target_seg"],
+                target_size=self.target_size,
+                degrees=self.degrees,
+                translate=self.translate,
+                scales=self.scale,
+                shear=self.shear,
+            )
+            sample["image"] = img
+            sample["target"] = target
+        return sample
+
+
+class DetectionMixup(DetectionTransform):
+    """
+    Mixup detection transform
+
+    Attributes:
+        input_dim: (tuple) input dimension.
+        mixup_scale: (tuple) scale range for the additional loaded image for mixup.
+        prob: (float) probability of applying mixup.
+        enable_mixup: (bool) whether to apply mixup at all (regardless of prob) (default=True).
+    """
+
+    def __init__(self, input_dim, mixup_scale, prob=1., enable_mixup=True):
+        super(DetectionMixup, self).__init__(additional_samples_count=1, non_empty_targets=True)
+        self.input_dim = input_dim
+        self.mixup_scale = mixup_scale
+        self.prob = prob
+        self.enable_mixup = enable_mixup
+
+    def close(self):
+        self.additional_samples_count = 0
+        self.enable_mixup = False
+
+    def __call__(self, sample: dict):
+        if self.enable_mixup and random.random() < self.prob:
+            origin_img, origin_labels = sample["image"], sample["target"]
+            cp_sample = sample["additional_samples"][0]
+            img, cp_labels = cp_sample["image"], cp_sample["target"]
+
+            img, cp_labels = _mirror(img, cp_labels, 0.5)
+            jit_factor = random.uniform(*self.mixup_scale)
+
+            if len(img.shape) == 3:
+                cp_img = np.ones((self.input_dim[0], self.input_dim[1], 3), dtype=np.uint8) * 114
+            else:
+                cp_img = np.ones(self.input_dim, dtype=np.uint8) * 114
+
+            cp_scale_ratio = min(self.input_dim[0] / img.shape[0], self.input_dim[1] / img.shape[1])
+            resized_img = cv2.resize(
+                img,
+                (int(img.shape[1] * cp_scale_ratio), int(img.shape[0] * cp_scale_ratio)),
+                interpolation=cv2.INTER_LINEAR,
+            )
+
+            cp_img[: int(img.shape[0] * cp_scale_ratio), : int(img.shape[1] * cp_scale_ratio)] = resized_img
+
+            cp_img = cv2.resize(
+                cp_img,
+                (int(cp_img.shape[1] * jit_factor), int(cp_img.shape[0] * jit_factor)),
+            )
+            cp_scale_ratio *= jit_factor
+
+            origin_h, origin_w = cp_img.shape[:2]
+            target_h, target_w = origin_img.shape[:2]
+            padded_img = np.zeros(
+                (max(origin_h, target_h), max(origin_w, target_w), 3), dtype=np.uint8
+            )
+            padded_img[:origin_h, :origin_w] = cp_img
+
+            x_offset, y_offset = 0, 0
+            if padded_img.shape[0] > target_h:
+                y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
+            if padded_img.shape[1] > target_w:
+                x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
+            padded_cropped_img = padded_img[y_offset: y_offset + target_h, x_offset: x_offset + target_w]
+
+            cp_bboxes_origin_np = adjust_box_anns(
+                cp_labels[:, :4].copy(), cp_scale_ratio, 0, 0, origin_w, origin_h
+            )
+            cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
+            cp_bboxes_transformed_np[:, 0::2] = np.clip(
+                cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w
+            )
+            cp_bboxes_transformed_np[:, 1::2] = np.clip(
+                cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h
+            )
+
+            cls_labels = cp_labels[:, 4:5].copy()
+            box_labels = cp_bboxes_transformed_np
+            labels = np.hstack((box_labels, cls_labels))
+            origin_labels = np.vstack((origin_labels, labels))
+            origin_img = origin_img.astype(np.float32)
+            origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(np.float32)
+
+            sample["image"], sample["target"] = origin_img.astype(np.uint8), origin_labels
+        return sample
+
+
+class DetectionPaddedRescale(DetectionTransform):
+    """
+    Preprocessing transform to be applied last of all transforms for validation.
+
+    Image: rescales and pads to self.input_dim.
+    Targets: pads empty targets to max_targets and rescales the box coordinates by the resize ratio
+    (boxes stay in xyxy format with the class label at the last index).
+
+    Attributes:
+        input_dim: (tuple) final input dimension (default=(640,640))
+        swap: (tuple) axes order for transposing the output image.
+        max_targets: (int) padding size for empty targets (default=50).
+    """
+
+    def __init__(self, input_dim, swap=(2, 0, 1), max_targets=50):
+        super(DetectionPaddedRescale, self).__init__()
+        self.swap = swap
+        self.input_dim = input_dim
+        self.max_targets = max_targets
+
+    def __call__(self, sample):
+        img, target = sample["image"], sample["target"]
+        if len(target) == 0:
+            new_target = np.zeros((self.max_targets, 5), dtype=np.float32)
+        else:
+            new_target = target.copy()
+
+        boxes = new_target[:, :4]
+        labels = new_target[:, 4]
+        img, r = rescale_and_pad_to_size(img, self.input_dim, self.swap)
+        boxes = xyxy2cxcywh(boxes)
+        boxes *= r
+        boxes = cxcywh2xyxy(boxes)
+        new_target = np.concatenate((boxes, labels[:, np.newaxis]), 1)
+
+        sample["image"] = img
+        sample["target"] = new_target
+        return sample
+
+
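Since DetectionPaddedRescale is meant to run last, a typical training pipeline built from the transforms in this PR would look roughly like the sketch below (the list-based application is illustrative; the dataset is expected to attach additional_samples for mosaic and mixup as shown earlier):

    train_transforms = [
        DetectionMosaic(input_dim=(640, 640), prob=1.0),
        DetectionRandomAffine(degrees=10, translate=0.1, scales=0.1, shear=10, target_size=(640, 640)),
        DetectionMixup(input_dim=(640, 640), mixup_scale=(0.5, 1.5), prob=1.0),
        DetectionHSV(prob=1.0),
        DetectionHorizontalFlip(prob=0.5),
        DetectionPaddedRescale(input_dim=(640, 640)),
    ]
    for t in train_transforms:
        sample = t(sample)  # every transform consumes and returns the sample dict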
+class DetectionHorizontalFlip(DetectionTransform):
+    """
+    Horizontal Flip for Detection
+
+    Attributes:
+        prob: float: probability of applying horizontal flip.
+        max_targets: int: max objects in single image, padding target to this size in case of empty image.
+    """
+
+    def __init__(self, prob, max_targets: int = 120):
+        super(DetectionHorizontalFlip, self).__init__()
+        self.prob = prob
+        self.max_targets = max_targets
+
+    def __call__(self, sample):
+        image, targets = sample["image"], sample["target"]
+        boxes = targets[:, :4]
+        if len(boxes) == 0:
+            targets = np.zeros((self.max_targets, 5), dtype=np.float32)
+            boxes = targets[:, :4]
+        image, boxes = _mirror(image, boxes, self.prob)
+        sample["image"] = image
+        sample["target"] = targets  # boxes is a view into targets, so the flipped coordinates are already in place
+        return sample
+
+
+class DetectionHSV(DetectionTransform):
+    """
+    Detection HSV transform.
+    """
+
+    def __init__(self, prob):
+        super(DetectionHSV, self).__init__()
+        self.prob = prob
+
+    def __call__(self, sample):
+        if random.random() < self.prob:
+            augment_hsv(sample["image"])
+        return sample
+
+
+class DetectionTargetsFormatTransform(DetectionTransform):
+    """
+    Detection targets format transform
+
+    Converts targets in input_format to output_format.
+    Attributes:
+        input_format: DetectionTargetsFormat: input target format
+        output_format: DetectionTargetsFormat: output target format
+        min_bbox_edge_size: float: bboxes with an edge size lower than this value will be removed.
+        max_targets: int: max objects in single image, padding target to this size.
+    """
+
+    def __init__(self, input_format: DetectionTargetsFormat = DetectionTargetsFormat.XYXY_LABEL,
+                 output_format: DetectionTargetsFormat = DetectionTargetsFormat.LABEL_CXCYWH,
+                 min_bbox_edge_size: float = 1, max_targets: int = 120):
+        super(DetectionTargetsFormatTransform, self).__init__()
+        self.input_format = input_format
+        self.output_format = output_format
+        self.min_bbox_edge_size = min_bbox_edge_size
+        self.max_targets = max_targets
+
+    def __call__(self, sample):
+        normalized_input = "NORMALIZED" in self.input_format.value
+        normalized_output = "NORMALIZED" in self.output_format.value
+        normalize = not normalized_input and normalized_output
+        denormalize = normalized_input and not normalized_output
+
+        label_first_in_input = self.input_format.value.split("_")[0] == "LABEL"
+        label_first_in_output = self.output_format.value.split("_")[0] == "LABEL"
+
+        input_xyxy_format = "XYXY" in self.input_format.value
+        output_xyxy_format = "XYXY" in self.output_format.value
+        convert2xyxy = not input_xyxy_format and output_xyxy_format
+        convert2cxcy = input_xyxy_format and not output_xyxy_format
+
+        image, targets = sample["image"], sample["target"]
+
+        if label_first_in_input:
+            boxes = targets[:, 1:]
+            labels = targets[:, 0]
+        else:
+            boxes = targets[:, :4]
+            labels = targets[:, 4]
+
+        if convert2cxcy:
+            boxes = xyxy2cxcywh(boxes)
+        elif convert2xyxy:
+            boxes = cxcywh2xyxy(boxes)
+
+        _, h, w = image.shape
+
+        if normalize:
+            boxes[:, 0] = boxes[:, 0] / w
+            boxes[:, 1] = boxes[:, 1] / h
+            boxes[:, 2] = boxes[:, 2] / w
+            boxes[:, 3] = boxes[:, 3] / h
+
+        elif denormalize:
+            boxes[:, 0] = boxes[:, 0] * w
+            boxes[:, 1] = boxes[:, 1] * h
+            boxes[:, 2] = boxes[:, 2] * w
+            boxes[:, 3] = boxes[:, 3] * h
+
+        min_bbox_edge_size = self.min_bbox_edge_size / max(w, h) if normalized_output else self.min_bbox_edge_size
+
+        cxcywh_boxes = boxes if not output_xyxy_format else xyxy2cxcywh(boxes.copy())
+
+        mask_b = np.minimum(cxcywh_boxes[:, 2], cxcywh_boxes[:, 3]) > min_bbox_edge_size
+        boxes_t = boxes[mask_b]
+        labels_t = labels[mask_b]
+
+        labels_t = np.expand_dims(labels_t, 1)
+        targets_t = np.hstack((labels_t, boxes_t)) if label_first_in_output else np.hstack((boxes_t, labels_t))
+        padded_targets = np.zeros((self.max_targets, 5))
+        padded_targets[range(len(targets_t))[: self.max_targets]] = targets_t[: self.max_targets]
+        padded_targets = np.ascontiguousarray(padded_targets, dtype=np.float32)
+
+        sample["target"] = padded_targets
+        return sample
+
+
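The unit tests added in this PR pin down the expected behavior; for instance, converting a label-first xyxy target to label-first cxcywh:

    target = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)  # [class, x1, y1, x2, y2]
    transform = DetectionTargetsFormatTransform(max_targets=1,
                                                input_format=DetectionTargetsFormat.LABEL_XYXY,
                                                output_format=DetectionTargetsFormat.LABEL_CXCYWH)
    out = transform({"image": np.zeros((3, 100, 200)), "target": target})["target"]
    # out == [[10., 30., 40., 20., 20.]], i.e. [class, cx, cy, w, h]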
+def get_aug_params(value: Union[tuple, float], center: float = 0):
+    """
+    Generates a random value for augmentations as described below
+
+    :param value: Union[tuple, float] defines the range of values for generation. Wen tuple-
+     drawn uniformly between (value[0], value[1]), and (center - value, center + value) when float
+    :param center: float, defines center to subtract when value is float.
+    :return: generated value
+    """
+    if isinstance(value, float):
+        return random.uniform(center - value, center + value)
+    elif len(value) == 2:
+        return random.uniform(value[0], value[1])
+    else:
+        raise ValueError(
+            "Affine params should be either a sequence containing two values\
+                          or single float values. Got {}".format(
+                value
+            )
+        )
+
+
+def get_affine_matrix(
+        target_size,
+        degrees=10,
+        translate=0.1,
+        scales=0.1,
+        shear=10,
+):
+    """
+    Returns a random affine transform matrix.
+
+    :param target_size: (tuple) desired output shape.
+
+    :param degrees:  (Union[tuple, float]) degrees for random rotation, when float the random values are drawn uniformly
+     from (-degrees, degrees)
+
+    :param translate:  (Union[tuple, float]) translate size (in pixels) for random translation, when float the random values
+     are drawn uniformly from (-translate, translate)
+
+    :param scales: (Union[tuple, float]) values for random rescale, when float the random values are drawn uniformly
+     from (1 - scales, 1 + scales)
+
+    :param shear: (Union[tuple, float]) degrees for random shear, when float the random values are drawn uniformly
+     from (-shear, shear)
+
+    :return: affine_transform_matrix, drawn_scale
+    """
+    twidth, theight = target_size
+
+    # Rotation and Scale
+    angle = get_aug_params(degrees)
+    scale = get_aug_params(scales, center=1.0)
+
+    if scale <= 0.0:
+        raise ValueError("Argument scale should be positive")
+
+    R = cv2.getRotationMatrix2D(angle=angle, center=(0, 0), scale=scale)
+
+    M = np.ones([2, 3])
+    # Shear
+    shear_x = math.tan(get_aug_params(shear) * math.pi / 180)
+    shear_y = math.tan(get_aug_params(shear) * math.pi / 180)
+
+    M[0] = R[0] + shear_y * R[1]
+    M[1] = R[1] + shear_x * R[0]
+
+    # Translation
+    translation_x = get_aug_params(translate) * twidth  # x translation (pixels)
+    translation_y = get_aug_params(translate) * theight  # y translation (pixels)
+
+    M[0, 2] = translation_x
+    M[1, 2] = translation_y
+
+    return M, scale
+
+
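The 2x3 matrix composes shear on top of rotation-and-scale (translation fills the last column), i.e. its linear part equals S @ R with S = [[1, tan(shear_y)], [tan(shear_x), 1]]. A quick numeric sanity check of that reading (not part of the PR):

    R = cv2.getRotationMatrix2D(angle=15, center=(0, 0), scale=1.2)
    shear_x, shear_y = math.tan(math.radians(5)), math.tan(math.radians(-3))
    S = np.array([[1, shear_y], [shear_x, 1]])
    M = np.ones((2, 3))
    M[0], M[1] = R[0] + shear_y * R[1], R[1] + shear_x * R[0]  # as in get_affine_matrix
    assert np.allclose(M[:, :2], S @ R[:, :2])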
+def apply_affine_to_bboxes(targets, targets_seg, target_size, M):
+    num_gts = len(targets)
+    twidth, theight = target_size
+    seg_is_present_mask = np.logical_or.reduce(~np.isnan(targets_seg), axis=1)
+    num_gts_masks = seg_is_present_mask.sum()
+    num_gts_boxes = num_gts - num_gts_masks
+
+    if num_gts_boxes:
+        # warp corner points
+        corner_points = np.ones((num_gts_boxes * 4, 3))
+        # x1y1, x2y2, x1y2, x2y1
+        corner_points[:, :2] = targets[~seg_is_present_mask][:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(num_gts_boxes * 4, 2)
+        corner_points = corner_points @ M.T  # apply affine transform
+        corner_points = corner_points.reshape(num_gts_boxes, 8)
+
+        # create new boxes
+        corner_xs = corner_points[:, 0::2]
+        corner_ys = corner_points[:, 1::2]
+        new_bboxes = (np.concatenate(
+            (np.min(corner_xs, 1), np.min(corner_ys, 1),
+             np.max(corner_xs, 1), np.max(corner_ys, 1))
+        ).reshape(4, -1).T)
+    else:
+        new_bboxes = np.ones((0, 4), dtype=float)
+
+    if num_gts_masks:
+        # warp segmentation points
+        num_seg_points = targets_seg.shape[1] // 2
+        corner_points_seg = np.ones((num_gts_masks * num_seg_points, 3))
+        corner_points_seg[:, :2] = targets_seg[seg_is_present_mask].reshape(num_gts_masks * num_seg_points, 2)
+        corner_points_seg = corner_points_seg @ M.T
+        corner_points_seg = corner_points_seg.reshape(num_gts_masks, num_seg_points * 2)
+
+        # create new boxes
+        seg_points_xs = corner_points_seg[:, 0::2]
+        seg_points_ys = corner_points_seg[:, 1::2]
+        new_tight_bboxes = (np.concatenate(
+            (np.nanmin(seg_points_xs, 1), np.nanmin(seg_points_ys, 1),
+             np.nanmax(seg_points_xs, 1), np.nanmax(seg_points_ys, 1))
+        ).reshape(4, -1).T)
+    else:
+        new_tight_bboxes = np.ones((0, 4), dtype=float)
+
+    targets[~seg_is_present_mask, :4] = new_bboxes
+    targets[seg_is_present_mask, :4] = new_tight_bboxes
+
+    # clip boxes
+    targets[:, [0, 2]] = targets[:, [0, 2]].clip(0, twidth)
+    targets[:, [1, 3]] = targets[:, [1, 3]].clip(0, theight)
+
+    return targets
+
+
+def random_affine(
+        img,
+        targets=(),
+        targets_seg=(),
+        target_size=(640, 640),
+        degrees=10,
+        translate=0.1,
+        scales=0.1,
+        shear=10,
+):
+    """
+    Performs random affine transform to img, targets
+
+    :param img: (array) input image.
+
+    :param targets: (array) input target.
+
+    :param targets_seg: (array) targets derived from segmentation masks.
+
+    :param target_size: (tuple) desired output shape.
+
+    :param degrees: (Union[tuple, float]) degrees for random rotation, when float the random values are drawn uniformly
+     from (-degrees, degrees)
+
+    :param translate: (Union[tuple, float]) translate size (in pixels) for random translation, when float the random values
+     are drawn uniformly from (-translate, translate)
+
+    :param scales: (Union[tuple, float]) values for random rescale, when float the random values are drawn uniformly
+     from (1 - scales, 1 + scales)
+
+    :param shear: (Union[tuple, float]) degrees for random shear, when float the random values are drawn uniformly
+     from (-shear, shear)
+
+    :return: transformed image, transformed targets
+    """
+    M, scale = get_affine_matrix(target_size, degrees, translate, scales, shear)
+
+    img = cv2.warpAffine(img, M, dsize=target_size, borderValue=(114, 114, 114))
+
+    # Transform label coordinates
+    if len(targets) > 0:
+        targets = apply_affine_to_bboxes(targets, targets_seg, target_size, M)
+
+    return img, targets
+
+
+def _mirror(image, boxes, prob=0.5):
+    """
+    Horizontal flips image and bboxes with probability prob.
+
+    :param image: (np.array) image to be flipped.
+    :param boxes: (np.array) bboxes to be modified.
+    :param prob: probability to perform flipping.
+    :return: flipped_image, flipped_bboxes
+    """
+    _, width, _ = image.shape
+    if random.random() < prob:
+        image = image[:, ::-1]
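+        # boxes[:, 2::-2] picks columns [x2, x1], so x1_new = width - x2 and x2_new = width - x1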
+        boxes[:, 0::2] = width - boxes[:, 2::-2]
+    return image, boxes
+
+
+def augment_hsv(img, hgain=5, sgain=30, vgain=30):
+    """
+    Randomly jitters hue, saturation and value of a BGR image, in place.
+
+    :param img: (np.array) BGR image, modified in place.
+    :param hgain: (int) maximal hue gain.
+    :param sgain: (int) maximal saturation gain.
+    :param vgain: (int) maximal value gain.
+    """
+    hsv_augs = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain]  # random gains
+    hsv_augs *= np.random.randint(0, 2, 3)  # random selection of h, s, v
+    hsv_augs = hsv_augs.astype(np.int16)
+    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
+
+    img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180
+    img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255)
+    img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255)
+
+    cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img)  # no return needed
+
+
+def rescale_and_pad_to_size(img, input_size, swap=(2, 0, 1), pad_val=114):
+    """
+    Rescales image according to minimum ratio between the target height /image height, target width / image width,
+    and pads the image to the target size.
+
+    :param img: Image to be rescaled.
+    :param input_size: Target size.
+    :param swap: Axes order for transposing the output image.
+    :param pad_val: Padding value for the added borders.
+    :return: rescaled image, ratio
+    """
+    if len(img.shape) == 3:
+        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * pad_val
+    else:
+        padded_img = np.ones(input_size, dtype=np.uint8) * pad_val
+
+    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+    resized_img = cv2.resize(
+        img,
+        (int(img.shape[1] * r), int(img.shape[0] * r)),
+        interpolation=cv2.INTER_LINEAR,
+    ).astype(np.uint8)
+    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+    padded_img = padded_img.transpose(swap)
+    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+    return padded_img, r
@@ -7,6 +7,7 @@ from typing import Callable, List, Union, Tuple
 import cv2
 from deprecated import deprecated
 from scipy.cluster.vq import kmeans
+from torch.utils.data._utils.collate import default_collate
 from tqdm import tqdm
 import matplotlib.pyplot as plt
 from PIL import Image
@@ -20,6 +21,25 @@ from super_gradients.common.abstractions.abstract_logger import get_logger
 from omegaconf import ListConfig
 
 
+class DetectionTargetsFormat(Enum):
+    """
+    Enum class for the different detection target formats.
+
+    When NORMALIZED is not specified, the format refers to unnormalized (pixel) image coordinates of the bboxes.
+
+    For example:
+    LABEL_NORMALIZED_XYXY means [class_idx, x1, y1, x2, y2] with coordinates normalized to the image size.
+    """
+    LABEL_XYXY = "LABEL_XYXY"
+    XYXY_LABEL = "XYXY_LABEL"
+    LABEL_NORMALIZED_XYXY = "LABEL_NORMALIZED_XYXY"
+    NORMALIZED_XYXY_LABEL = "NORMALIZED_XYXY_LABEL"
+    LABEL_CXCYWH = "LABEL_CXCYWH"
+    CXCYWH_LABEL = "CXCYWH_LABEL"
+    LABEL_NORMALIZED_CXCYWH = "LABEL_NORMALIZED_CXCYWH"
+    NORMALIZED_CXCYWH_LABEL = "NORMALIZED_CXCYWH_LABEL"
+
+
 def base_detection_collate_fn(batch):
     """
     Batch Processing helper function for detection training/testing.
@@ -756,7 +776,8 @@ def calc_batch_prediction_accuracy(output: torch.Tensor, targets: torch.Tensor,
         if pred is None:
             if labels_num:
                 batch_metrics.append(
-                    (np.zeros((0, num_ious), dtype=np.bool), np.array([], dtype=np.float32), np.array([], dtype=np.float32), target_class))
+                    (np.zeros((0, num_ious), dtype=bool), np.array([], dtype=np.float32),
+                     np.array([], dtype=np.float32), target_class))
             continue
 
         # CHANGE bboxes TO FIT THE IMAGE SIZE
@@ -943,7 +964,8 @@ def plot_coco_datasaet_images_with_detections(data_loader, num_images_to_plot=1)
         ns = np.ceil(batch_size ** 0.5)
 
         for i in range(batch_size):
-            boxes = convert_xywh_bbox_to_xyxy(torch.from_numpy(targets[targets[:, 0] == i, 2:6])).cpu().detach().numpy().T
+            boxes = convert_xywh_bbox_to_xyxy(
+                torch.from_numpy(targets[targets[:, 0] == i, 2:6])).cpu().detach().numpy().T
             boxes[[0, 2]] *= w
             boxes[[1, 3]] *= h
             plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
@@ -1164,3 +1186,97 @@ class Anchors(nn.Module):
 
     def __repr__(self):
         return f"anchors_list: {self.__anchors_list} strides: {self.__strides}"
+
+
+def xyxy2cxcywh(bboxes):
+    """
+    Transforms bboxes, in place, from xyxy format to cxcywh (center-x, center-y, width, height) format
+    :param bboxes: array, shaped (nboxes, 4)
+    :return: modified bboxes
+    """
+    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
+    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
+    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
+    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
+    return bboxes
+
+
+def cxcywh2xyxy(bboxes):
+    """
+    Transforms bboxes, in place, from cxcywh (center-x, center-y, width, height) format to xyxy format
+    :param bboxes: array, shaped (nboxes, 4)
+    :return: modified bboxes
+    """
+    bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] * 0.5
+    bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] * 0.5
+    bboxes[:, 3] = bboxes[:, 3] + bboxes[:, 1]
+    bboxes[:, 2] = bboxes[:, 2] + bboxes[:, 0]
+    return bboxes
+
+
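Note that both helpers modify their input in place (hence the .copy() calls at their call sites). A worked round trip, consistent with the unit tests in this PR:

    boxes = np.array([[20., 30., 40., 50.]])  # xyxy
    boxes = xyxy2cxcywh(boxes)                # [[30., 40., 20., 20.]] (cx, cy, w, h)
    boxes = cxcywh2xyxy(boxes)                # [[20., 30., 40., 50.]] back to xyxy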
+def get_mosaic_coordinate(mosaic_index, xc, yc, w, h, input_h, input_w):
+    """
+    Returns the mosaic coordinates of final mosaic image according to mosaic image index.
+
+    :param mosaic_index: (int) mosaic image index
+    :param xc: (int) center x coordinate of the entire mosaic grid.
+    :param yc: (int) center y coordinate of the entire mosaic grid.
+    :param w: (int) width of the image to be placed in the mosaic.
+    :param h: (int) height of the image to be placed in the mosaic.
+    :param input_h: (int) image input height (should be 1/2 of the final mosaic output image height).
+    :param input_w: (int) image input width (should be 1/2 of the final mosaic output image width).
+    :return: (x1, y1, x2, y2), (x1s, y1s, x2s, y2s) where (x1, y1, x2, y2) are the coordinates in the final mosaic
+        output image, and (x1s, y1s, x2s, y2s) are the coordinates in the placed image.
+    """
+    # index0 to top left part of image
+    if mosaic_index == 0:
+        x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
+        small_coord = w - (x2 - x1), h - (y2 - y1), w, h
+    # index1 to top right part of image
+    elif mosaic_index == 1:
+        x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
+        small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
+    # index2 to bottom left part of image
+    elif mosaic_index == 2:
+        x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
+        small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
+    # index3 to bottom right part of image
+    elif mosaic_index == 3:
+        x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h)  # noqa
+        small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
+    return (x1, y1, x2, y2), small_coord
+
+
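A worked example for the top-left tile (mosaic_index=0), assuming 640x640 inputs and a mosaic center at (700, 700):

    (x1, y1, x2, y2), (x1s, y1s, x2s, y2s) = get_mosaic_coordinate(0, xc=700, yc=700, w=640, h=640,
                                                                   input_h=640, input_w=640)
    # (x1, y1, x2, y2) == (60, 60, 700, 700)   -- region in the 1280x1280 mosaic canvas
    # (x1s, y1s, x2s, y2s) == (0, 0, 640, 640) -- matching crop from the placed image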
+def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):
+    """
+    Adjusts the bbox annotations of rescaled, padded image.
+
+    :param bbox: (np.array) bbox to modify.
+    :param scale_ratio: (float) scale ratio between rescale output image and original one.
+    :param padw: (int) width padding size.
+    :param padh: (int) height padding size.
+    :param w_max: (int) width border.
+    :param h_max: (int) height border.
+    :return: modified bbox (np.array)
+    """
+    bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max)
+    bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max)
+    return bbox
+
+
+class YoloXCollateFN:
+    """
+    Collate function for Yolox training
+    """
+    def __call__(self, data):
+        batch = default_collate(data)
+        ims = batch[0]
+        targets = batch[1]
+        nlabel = (targets.sum(dim=2) > 0).sum(dim=1)  # number of objects
+        targets_merged = []
+        for i in range(targets.shape[0]):
+            targets_im = targets[i, :nlabel[i]]
+            batch_column = targets.new_ones((targets_im.shape[0], 1)) * i
+            targets_merged.append(torch.cat((batch_column, targets_im), 1))
+        targets = torch.cat(targets_merged, 0)
+        return ims, targets
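A minimal sketch of wiring the collate function into a loader (train_dataset is a placeholder for any detection dataset whose samples are (image, padded_targets) pairs; it is not defined in this PR):

    from torch.utils.data import DataLoader

    loader = DataLoader(train_dataset, batch_size=16, collate_fn=YoloXCollateFN())
    ims, targets = next(iter(loader))
    # ims: (16, C, H, W); targets: (n_objects_in_batch, 1 + 5), sample index prepended to each row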
@@ -16,6 +16,7 @@ from tests.unit_tests.kd_model_test import KDModelTest
 from tests.unit_tests.dice_loss_test import DiceLossTest
 from tests.unit_tests.vit_unit_test import TestViT
 from tests.unit_tests.lr_cooldown_test import LRCooldownTest
+from tests.unit_tests.detection_targets_format_transform_test import DetectionTargetsTransformTest
 
 
 class CoreUnitTestSuiteRunner:
@@ -55,6 +56,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(KDModelTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(InitializeWithDataloadersTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LRCooldownTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DetectionTargetsTransformTest))
 
     def _add_modules_to_end_to_end_tests_suite(self):
         """
tests/unit_tests/detection_targets_format_transform_test.py (new file):

import numpy as np
import unittest
from super_gradients.training.transforms.transforms import DetectionTargetsFormatTransform
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat


class DetectionTargetsTransformTest(unittest.TestCase):
    def setUp(self) -> None:
        self.image = np.zeros((3, 100, 200))

    def test_label_first_2_label_last(self):
        input = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
        output = np.array([[50, 10, 20, 30, 40]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.XYXY_LABEL,
                                                    output_format=DetectionTargetsFormat.LABEL_XYXY)
        sample = {"image": self.image, "target": input}
        self.assertTrue(np.array_equal(transform(sample)["target"], output))

    def test_xyxy_2_normalized_xyxy(self):
        input = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
        _, h, w = self.image.shape
        output = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.LABEL_XYXY,
                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY)
        sample = {"image": self.image, "target": input}
        t_output = transform(sample)["target"]
        self.assertTrue(np.array_equal(output, t_output))

    def test_xyxy_2_cxcywh(self):
        input = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
        _, h, w = self.image.shape
        output = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.LABEL_XYXY,
                                                    output_format=DetectionTargetsFormat.LABEL_CXCYWH)
        sample = {"image": self.image, "target": input}
        t_output = transform(sample)["target"]
        self.assertTrue(np.array_equal(output, t_output))

    def test_xyxy_2_normalized_cxcywh(self):
        input = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
        _, h, w = self.image.shape
        output = np.array([[10, 30 / w, 40 / h, 20 / w, 20 / h]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.LABEL_XYXY,
                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_CXCYWH)
        sample = {"image": self.image, "target": input}
        t_output = transform(sample)["target"]
        self.assertTrue(np.array_equal(output, t_output))

    def test_normalized_xyxy_2_cxcywh(self):
        _, h, w = self.image.shape
        input = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
        output = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY,
                                                    output_format=DetectionTargetsFormat.LABEL_CXCYWH)
        sample = {"image": self.image, "target": input}
        t_output = transform(sample)["target"]
        self.assertTrue(np.allclose(output, t_output))

    def test_normalized_xyxy_2_normalized_cxcywh(self):
        _, h, w = self.image.shape
        input = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
        output = np.array([[10, 30 / w, 40 / h, 20 / w, 20 / h]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY,
                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_CXCYWH)
        sample = {"image": self.image, "target": input}
        t_output = transform(sample)["target"]
        self.assertTrue(np.allclose(output, t_output))

    def test_cxcywh_2_xyxy(self):
        output = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
        input = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.LABEL_CXCYWH,
                                                    output_format=DetectionTargetsFormat.LABEL_XYXY)
        sample = {"image": self.image, "target": input}
        t_output = transform(sample)["target"]
        self.assertTrue(np.array_equal(output, t_output))

    def test_cxcywh_2_normalized_xyxy(self):
        _, h, w = self.image.shape
        output = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
        input = np.array([[10, 30, 40, 20, 20]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.LABEL_CXCYWH,
                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY)
        sample = {"image": self.image, "target": input}
        t_output = transform(sample)["target"]
        self.assertTrue(np.array_equal(output, t_output))

    def test_normalized_cxcywh_2_xyxy(self):
        _, h, w = self.image.shape
        input = np.array([[10, 30 / w, 40 / h, 20 / w, 20 / h]], dtype=np.float32)
        output = np.array([[10, 20, 30, 40, 50]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.LABEL_NORMALIZED_CXCYWH,
                                                    output_format=DetectionTargetsFormat.LABEL_XYXY)
        sample = {"image": self.image, "target": input}
        t_output = transform(sample)["target"]
        self.assertTrue(np.allclose(output, t_output))

    def test_normalized_cxcywh_2_normalized_xyxy(self):
        _, h, w = self.image.shape
        output = np.array([[10, 20 / w, 30 / h, 40 / w, 50 / h]], dtype=np.float32)
        input = np.array([[10, 30 / w, 40 / h, 20 / w, 20 / h]], dtype=np.float32)
        transform = DetectionTargetsFormatTransform(max_targets=1,
                                                    input_format=DetectionTargetsFormat.LABEL_NORMALIZED_CXCYWH,
                                                    output_format=DetectionTargetsFormat.LABEL_NORMALIZED_XYXY)
        sample = {"image": self.image, "target": input}
        t_output = transform(sample)["target"]
        self.assertTrue(np.allclose(output, t_output))


if __name__ == '__main__':
    unittest.main()