Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#260 feature/SG-175 New NMS

Merged
Ofri Masad merged 1 commit into Deci-AI:master from deci-ai:feature/SG-175_New_NMS
@@ -60,26 +60,40 @@ class YoloV5PostPredictionCallback(DetectionPostPredictionCallback):
     """Non-Maximum Suppression (NMS) module"""
     """Non-Maximum Suppression (NMS) module"""
 
 
     def __init__(self, conf: float = 0.001, iou: float = 0.6, classes: List[int] = None,
     def __init__(self, conf: float = 0.001, iou: float = 0.6, classes: List[int] = None,
-                 nms_type: NMS_Type = NMS_Type.ITERATIVE, max_predictions: int = 300):
+                 nms_type: NMS_Type = NMS_Type.ITERATIVE, max_predictions: int = 300,
+                 with_confidence: bool = True):
         """
         """
         :param conf: confidence threshold
         :param conf: confidence threshold
         :param iou: IoU threshold                                       (used in NMS_Type.ITERATIVE)
         :param iou: IoU threshold                                       (used in NMS_Type.ITERATIVE)
         :param classes: (optional list) filter by class                 (used in NMS_Type.ITERATIVE)
         :param classes: (optional list) filter by class                 (used in NMS_Type.ITERATIVE)
         :param nms_type: the type of nms to use (iterative or matrix)
         :param nms_type: the type of nms to use (iterative or matrix)
         :param max_predictions: maximum number of boxes to output       (used in NMS_Type.MATRIX)
         :param max_predictions: maximum number of boxes to output       (used in NMS_Type.MATRIX)
+        :param with_confidence: in NMS, whether to multiply objectness  (used in NMS_Type.ITERATIVE)
+                                score with class score
         """
         """
         super(YoloV5PostPredictionCallback, self).__init__()
         super(YoloV5PostPredictionCallback, self).__init__()
         self.conf = conf
         self.conf = conf
         self.iou = iou
         self.iou = iou
         self.classes = classes
         self.classes = classes
         self.nms_type = nms_type
         self.nms_type = nms_type
-        self.max_predictions = max_predictions
+        self.max_pred = max_predictions
+        self.with_confidence = with_confidence
 
 
     def forward(self, x, device: str = None):
     def forward(self, x, device: str = None):
+
         if self.nms_type == NMS_Type.ITERATIVE:
         if self.nms_type == NMS_Type.ITERATIVE:
-            return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
+            nms_result = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou,
+                                             with_confidence=self.with_confidence)
         else:
         else:
-            return matrix_non_max_suppression(x[0], conf_thres=self.conf, max_num_of_detections=self.max_predictions)
+            nms_result = matrix_non_max_suppression(x[0], conf_thres=self.conf,
+                                                    max_num_of_detections=self.max_pred)
+
+        return self._filter_max_predictions(nms_result)
+
+    def _filter_max_predictions(self, res: List) -> List:
+        res[:] = [im[:self.max_pred] if (im is not None and im.shape[0] > self.max_pred) else im for im in res]
+
+        return res
 
 
 
 
 class Concat(nn.Module):
 class Concat(nn.Module):
Discard
@@ -613,91 +613,52 @@ def box_iou(box1, box2):
     return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
     return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
 
 
 
 
-def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None,
-                        agnostic=False, multi_label_per_box=None):  # noqa: C901
-    """Performs Non-Maximum Suppression (NMS) on inference results
+def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6,
+                        multi_label_per_box: bool = True, with_confidence: bool = False):
+    """
+    Performs Non-Maximum Suppression (NMS) on inference results
         :param prediction: raw model prediction
         :param prediction: raw model prediction
         :param conf_thres: below the confidence threshold - predictions are discarded
         :param conf_thres: below the confidence threshold - predictions are discarded
         :param iou_thres: IoU threshold for the nms algorithm
         :param iou_thres: IoU threshold for the nms algorithm
-        :param merge: Merge boxes using weighted mean
-        :param classes: (optional list) filter by class
-        :param agnostic: Determines if is class agnostic. i.e. may display a box with 2 predictions
         :param multi_label_per_box: whether to re-use each box with all possible labels
         :param multi_label_per_box: whether to re-use each box with all possible labels
                                     (instead of the maximum confidence all confidences above threshold
                                     (instead of the maximum confidence all confidences above threshold
                                     will be sent to NMS); by default is set to True
                                     will be sent to NMS); by default is set to True
+        :param with_confidence: whether to multiply objectness score with class score.
+                                usually valid for Yolo models only.
         :return:  (x1, y1, x2, y2, object_conf, class_conf, class)
         :return:  (x1, y1, x2, y2, object_conf, class_conf, class)
     Returns:
     Returns:
          detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
          detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
     """
     """
-    # TODO: INVESTIGATE THE COMMENTED OUT PARTS AND DECIDE IF TO ERASE OR UNCOMMENT
-    number_of_classes = prediction[0].shape[1] - 5
-    candidates_above_thres = prediction[..., 4] > conf_thres
-
-    # Settings
-    # min_box_width_and_height = 2
-    max_box_width_and_height = 4096
-    max_num_of_detections = 300
-    require_redundant_detections = True
-    # when set to True (adds 0.5ms/img)
-    multi_label_per_box = multi_label_per_box if multi_label_per_box is not None else number_of_classes > 1
+    candidates_above_thres = prediction[..., 4] > conf_thres  # filter by confidence
     output = [None] * prediction.shape[0]
     output = [None] * prediction.shape[0]
+
     for image_idx, pred in enumerate(prediction):
     for image_idx, pred in enumerate(prediction):
-        # Apply constraints
-        # pred[((pred[..., 2:4] < min_box_width_and_height) | (pred[..., 2:4] > max_box_width_and_height)).any(1), 4] = 0  # width-height
-        pred = pred[candidates_above_thres[image_idx]]  # confidence
 
 
-        # If none remain process next image
-        if not pred.shape[0]:
+        pred = pred[candidates_above_thres[image_idx]]  # confident
+
+        if not pred.shape[0]:  # If none remain process next image
             continue
             continue
 
 
-        # Compute confidence = object_conf * class_conf
-        pred[:, 5:] *= pred[:, 4:5]
-        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
-        box = convert_xywh_bbox_to_xyxy(pred[:, :4])
+        if with_confidence:
+            pred[:, 5:] *= pred[:, 4:5]  # multiply objectness score with class score
+
+        box = convert_xywh_bbox_to_xyxy(pred[:, :4])  # xywh to xyxy
 
 
         # Detections matrix nx6 (xyxy, conf, cls)
         # Detections matrix nx6 (xyxy, conf, cls)
-        if multi_label_per_box:
+        if multi_label_per_box:  # try for all good confidence classes
             i, j = (pred[:, 5:] > conf_thres).nonzero(as_tuple=False).T
             i, j = (pred[:, 5:] > conf_thres).nonzero(as_tuple=False).T
             pred = torch.cat((box[i], pred[i, j + 5, None], j[:, None].float()), 1)
             pred = torch.cat((box[i], pred[i, j + 5, None], j[:, None].float()), 1)
+
         else:  # best class only
         else:  # best class only
             conf, j = pred[:, 5:].max(1, keepdim=True)
             conf, j = pred[:, 5:].max(1, keepdim=True)
             pred = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
             pred = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
 
 
-        # Filter by class
-        if classes:
-            pred = pred[(pred[:, 5:6] == torch.tensor(classes, device=pred.device)).any(1)]
-
-        # Apply finite constraint
-        # if not torch.isfinite(x).all():
-        #     x = x[torch.isfinite(x).all(1)]
-
-        # If none remain process next image
-        number_of_boxes = pred.shape[0]
-        if not number_of_boxes:
+        if not pred.shape[0]:  # If none remain process next image
             continue
             continue
 
 
-        # Sort by confidence
-        # x = x[x[:, 4].argsort(descending=True)]
-
-        # Batched NMS
-        # CREATE AN OFFSET OF THE PREDICTIVE BOX OF DIFFERENT CLASSES IF not agnostic
-        offset = pred[:, 5:6] * (0 if agnostic else max_box_width_and_height)
-        boxes, scores = pred[:, :4] + offset, pred[:, 4]
-        idx_to_keep = torch.ops.torchvision.nms(boxes, scores, iou_thres)
-        if idx_to_keep.shape[0] > max_num_of_detections:  # limit number of detections
-            idx_to_keep = idx_to_keep[:max_num_of_detections]
-        if merge and (1 < number_of_boxes < 3000):
-            try:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
-                iou = box_iou(boxes[idx_to_keep], boxes) > iou_thres  # iou matrix
-                box_weights = iou * scores[None]
-                # MERGED BOXES
-                pred[idx_to_keep, :4] = torch.mm(box_weights, pred[:, :4]).float() / box_weights.sum(1, keepdim=True)
-                if require_redundant_detections:
-                    idx_to_keep = idx_to_keep[iou.sum(1) > 1]
-            except RuntimeError:  # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
-                print(pred, idx_to_keep, pred.shape, idx_to_keep.shape)
-                pass
-
+        # Apply torch batched NMS algorithm
+        boxes, scores, cls_idx = pred[:, :4], pred[:, 4], pred[:, 5]
+        idx_to_keep = torchvision.ops.boxes.batched_nms(boxes, scores, cls_idx, iou_thres)
         output[image_idx] = pred[idx_to_keep]
         output[image_idx] = pred[idx_to_keep]
 
 
     return output
     return output
Discard