Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#21009 `ultralytics 8.3.154` Refactor `Validator` and `Metrics` classes

Merged
Ghost merged 1 commit into Ultralytics:main from ultralytics:validator-cleanup
@@ -143,8 +143,11 @@ To train a YOLO11 model using JupyterLab:
 5. Visualize training results using JupyterLab's built-in plotting capabilities:
 
     ```python
-    %matplotlib inline
+    import matplotlib
+
     from ultralytics.utils.plotting import plot_results
+
+    matplotlib.use("inline")  # or 'notebook' for interactive
     plot_results(results)
     ```
 
Discard
@@ -325,7 +325,7 @@ To use YOLOv7 ONNX model with Ultralytics:
 
 2. Install the `TensorRT` Python package:
 
-    ```python
+    ```bash
     pip install tensorrt
     ```
 
Discard
@@ -43,14 +43,6 @@ keywords: ultralytics, plotting, utilities, documentation, data visualization, a
 
 <br><br><hr><br>
 
-## ::: ultralytics.utils.plotting.output_to_target
-
-<br><br><hr><br>
-
-## ::: ultralytics.utils.plotting.output_to_rotated_target
-
-<br><br><hr><br>
-
 ## ::: ultralytics.utils.plotting.feature_visualization
 
 <br><br>
Discard
@@ -76,16 +76,21 @@ Train YOLO11n-cls on the MNIST160 dataset for 100 [epochs](https://www.ultralyti
 
     Ultralytics YOLO classification uses [torchvision.transforms.RandomResizedCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.RandomResizedCrop.html) for training augmentation and [torchvision.transforms.CenterCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.CenterCrop.html) for validation/inference.
     For images with extreme aspect ratios, consider using [torchvision.transforms.Resize](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) instead. The example below shows how to customize augmentations for classification training.
+
     ```python
     import torch
     import torchvision.transforms as T
 
+    from ultralytics import YOLO
     from ultralytics.data.dataset import ClassificationDataset
     from ultralytics.models.yolo.classify import ClassificationTrainer
 
 
     class CustomizedDataset(ClassificationDataset):
+        """A customized dataset class for image classification with enhanced data augmentation transforms."""
+
         def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
+            """Initialize a customized classification dataset with enhanced data augmentation transforms."""
             super().__init__(root, args, augment, prefix)
             train_transforms = T.Compose(
                 [
@@ -110,12 +115,13 @@ Train YOLO11n-cls on the MNIST160 dataset for 100 [epochs](https://www.ultralyti
 
 
     class CustomizedTrainer(ClassificationTrainer):
+        """A customized trainer class for YOLO classification models with enhanced dataset handling."""
+
         def build_dataset(self, img_path: str, mode: str = "train", batch=None):
+            """Build a customized dataset for classification training or validation."""
             return CustomizedDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
 
 
-    from ultralytics import YOLO
-
     model = YOLO("yolo11n-cls.pt")
     model.train(data="imagenet1000", trainer=CustomizedTrainer, epochs=10, imgsz=224, batch=64)
     ```
Discard
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-__version__ = "8.3.153"
+__version__ = "8.3.154"
 
 import os
 
Discard
@@ -82,7 +82,6 @@ class BaseValidator:
         update_metrics: Update metrics based on predictions and batch.
         finalize_metrics: Finalize and return all metrics.
         get_stats: Return statistics about the model's performance.
-        check_stats: Check statistics.
         print_results: Print the results of the model's predictions.
         get_desc: Get description of the YOLO model.
         on_plot: Register plots for visualization.
@@ -226,7 +225,6 @@ class BaseValidator:
 
             self.run_callbacks("on_val_batch_end")
         stats = self.get_stats()
-        self.check_stats(stats)
         self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
         self.finalize_metrics()
         self.print_results()
@@ -334,10 +332,6 @@ class BaseValidator:
         """Return statistics about the model's performance."""
         return {}
 
-    def check_stats(self, stats):
-        """Check statistics."""
-        pass
-
     def print_results(self):
         """Print the results of the model's predictions."""
         pass
Discard
@@ -1,7 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from ultralytics.models.yolo.segment import SegmentationValidator
-from ultralytics.utils.metrics import SegmentMetrics
 
 
 class FastSAMValidator(SegmentationValidator):
@@ -39,4 +38,3 @@ class FastSAMValidator(SegmentationValidator):
         super().__init__(dataloader, save_dir, args, _callbacks)
         self.args.task = "segment"
         self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
-        self.metrics = SegmentMetrics(save_dir=self.save_dir)
Discard
@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from typing import Any, Dict, List, Tuple, Union
+
 import torch
 
 from ultralytics.data import YOLODataset
@@ -151,15 +153,21 @@ class RTDETRValidator(DetectionValidator):
             data=self.data,
         )
 
-    def postprocess(self, preds):
+    def postprocess(
+        self, preds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
+    ) -> List[Dict[str, torch.Tensor]]:
         """
         Apply Non-maximum suppression to prediction outputs.
 
         Args:
-            preds (list | tuple | torch.Tensor): Raw predictions from the model.
+            preds (torch.Tensor | List | Tuple): Raw predictions from the model. If tensor, should have shape
+                (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class scores.
 
         Returns:
-            (list[torch.Tensor]): List of processed predictions for each image in batch.
+            (List[Dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
+                - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates
+                - 'conf': Tensor of shape (N,) with confidence scores
+                - 'cls': Tensor of shape (N,) with class indices
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -176,18 +184,19 @@ class RTDETRValidator(DetectionValidator):
             pred = pred[score.argsort(descending=True)]
             outputs[i] = pred[score > self.args.conf]
 
-        return outputs
+        return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
 
-    def _prepare_batch(self, si, batch):
+    def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
         Prepare a batch for validation by applying necessary transformations.
 
         Args:
             si (int): Batch index.
-            batch (dict): Batch data containing images and annotations.
+            batch (Dict[str, Any]): Batch data containing images and annotations.
 
         Returns:
-            (dict): Prepared batch with transformed annotations.
+            (Dict[str, Any]): Prepared batch with transformed annotations containing cls, bboxes,
+                ori_shape, imgsz, and ratio_pad.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
@@ -199,20 +208,23 @@ class RTDETRValidator(DetectionValidator):
             bbox = ops.xywh2xyxy(bbox)  # target boxes
             bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
             bbox[..., [1, 3]] *= ori_shape[0]  # native-space pred
-        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+        return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred, pbatch):
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         """
         Prepare predictions by scaling bounding boxes to original image dimensions.
 
         Args:
-            pred (torch.Tensor): Raw predictions.
-            pbatch (dict): Prepared batch information.
+            pred (Dict[str, torch.Tensor]): Raw predictions containing 'cls', 'bboxes', and 'conf'.
+            pbatch (Dict[str, Any]): Prepared batch information containing 'ori_shape' and other metadata.
 
         Returns:
-            (torch.Tensor): Predictions scaled to original image dimensions.
+            (Dict[str, torch.Tensor]): Predictions scaled to original image dimensions.
         """
-        predn = pred.clone()
-        predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
-        predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
-        return predn.float()
+        cls = pred["cls"]
+        if self.args.single_cls:
+            cls *= 0
+        bboxes = pred["bboxes"].clone()
+        bboxes[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+        bboxes[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
+        return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
Discard
@@ -1,5 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
 import torch
 
 from ultralytics.data import ClassificationDataset, build_dataloader
@@ -48,7 +51,7 @@ class ClassificationValidator(BaseValidator):
         Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
     """
 
-    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
         """
         Initialize ClassificationValidator with dataloader, save directory, and other parameters.
 
@@ -70,28 +73,26 @@ class ClassificationValidator(BaseValidator):
         self.args.task = "classify"
         self.metrics = ClassifyMetrics()
 
-    def get_desc(self):
+    def get_desc(self) -> str:
         """Return a formatted string summarizing classification metrics."""
         return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
 
-    def init_metrics(self, model):
+    def init_metrics(self, model: torch.nn.Module) -> None:
         """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
         self.names = model.names
         self.nc = len(model.names)
-        self.confusion_matrix = ConfusionMatrix(
-            nc=self.nc, conf=self.args.conf, names=self.names.values(), task="classify"
-        )
         self.pred = []
         self.targets = []
+        self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))
 
-    def preprocess(self, batch):
+    def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """Preprocess input batch by moving data to device and converting to appropriate dtype."""
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
         batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
         batch["cls"] = batch["cls"].to(self.device)
         return batch
 
-    def update_metrics(self, preds, batch):
+    def update_metrics(self, preds: torch.Tensor, batch: Dict[str, Any]) -> None:
         """
         Update running metrics with model predictions and batch targets.
 
@@ -127,23 +128,23 @@ class ClassificationValidator(BaseValidator):
             for normalize in True, False:
                 self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
         self.metrics.speed = self.speed
-        self.metrics.confusion_matrix = self.confusion_matrix
         self.metrics.save_dir = self.save_dir
+        self.metrics.confusion_matrix = self.confusion_matrix
 
-    def postprocess(self, preds):
+    def postprocess(self, preds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> torch.Tensor:
         """Extract the primary prediction from model output if it's in a list or tuple format."""
         return preds[0] if isinstance(preds, (list, tuple)) else preds
 
-    def get_stats(self):
+    def get_stats(self) -> Dict[str, float]:
         """Calculate and return a dictionary of metrics by processing targets and predictions."""
         self.metrics.process(self.targets, self.pred)
         return self.metrics.results_dict
 
-    def build_dataset(self, img_path):
+    def build_dataset(self, img_path: str) -> ClassificationDataset:
         """Create a ClassificationDataset instance for validation."""
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
-    def get_dataloader(self, dataset_path, batch_size):
+    def get_dataloader(self, dataset_path: Union[Path, str], batch_size: int) -> torch.utils.data.DataLoader:
         """
         Build and return a data loader for classification validation.
 
@@ -157,17 +158,17 @@ class ClassificationValidator(BaseValidator):
         dataset = self.build_dataset(dataset_path)
         return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
 
-    def print_results(self):
+    def print_results(self) -> None:
         """Print evaluation metrics for the classification model."""
         pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
 
-    def plot_val_samples(self, batch, ni):
+    def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
         """
         Plot validation image samples with their ground truth labels.
 
         Args:
-            batch (dict): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
+            batch (Dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
             ni (int): Batch index used for naming the output file.
 
         Examples:
@@ -175,21 +176,20 @@ class ClassificationValidator(BaseValidator):
             >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
             >>> validator.plot_val_samples(batch, 0)
         """
+        batch["batch_idx"] = torch.arange(len(batch["img"]))  # add batch index for plotting
         plot_images(
-            images=batch["img"],
-            batch_idx=torch.arange(len(batch["img"])),
-            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            labels=batch,
             fname=self.save_dir / f"val_batch{ni}_labels.jpg",
             names=self.names,
             on_plot=self.on_plot,
         )
 
-    def plot_predictions(self, batch, preds, ni):
+    def plot_predictions(self, batch: Dict[str, Any], preds: torch.Tensor, ni: int) -> None:
         """
         Plot images with their predicted class labels and save the visualization.
 
         Args:
-            batch (dict): Batch data containing images and other information.
+            batch (Dict[str, Any]): Batch data containing images and other information.
             preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
             ni (int): Batch index used for naming the output file.
 
@@ -199,10 +199,13 @@ class ClassificationValidator(BaseValidator):
             >>> preds = torch.rand(16, 10)  # 16 images, 10 classes
             >>> validator.plot_predictions(batch, preds, 0)
         """
-        plot_images(
-            batch["img"],
+        batched_preds = dict(
+            img=batch["img"],
             batch_idx=torch.arange(len(batch["img"])),
             cls=torch.argmax(preds, dim=1),
+        )
+        plot_images(
+            batched_preds,
             fname=self.save_dir / f"val_batch{ni}_pred.jpg",
             names=self.names,
             on_plot=self.on_plot,
Discard
@@ -3,7 +3,7 @@
 import math
 import random
 from copy import copy
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 import torch.nn as nn
@@ -178,19 +178,16 @@ class DetectionTrainer(BaseTrainer):
             "Size",
         )
 
-    def plot_training_samples(self, batch: Dict, ni: int):
+    def plot_training_samples(self, batch: Dict[str, Any], ni: int) -> None:
         """
         Plot training samples with their annotations.
 
         Args:
-            batch (Dict): Dictionary containing batch data.
+            batch (Dict[str, Any]): Dictionary containing batch data.
             ni (int): Number of iterations.
         """
         plot_images(
-            images=batch["img"],
-            batch_idx=batch["batch_idx"],
-            cls=batch["cls"].squeeze(-1),
-            bboxes=batch["bboxes"],
+            labels=batch,
             paths=batch["im_file"],
             fname=self.save_dir / f"train_batch{ni}.jpg",
             on_plot=self.on_plot,
Discard
@@ -12,7 +12,7 @@ from ultralytics.engine.validator import BaseValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.plotting import plot_images
 
 
 class DetectionValidator(BaseValidator):
@@ -23,8 +23,6 @@ class DetectionValidator(BaseValidator):
     prediction processing, and visualization of results.
 
     Attributes:
-        nt_per_class (np.ndarray): Number of targets per class.
-        nt_per_image (np.ndarray): Number of targets per image.
         is_coco (bool): Whether the dataset is COCO.
         is_lvis (bool): Whether the dataset is LVIS.
         class_map (List[int]): Mapping from model class indices to dataset class indices.
@@ -53,15 +51,13 @@ class DetectionValidator(BaseValidator):
             _callbacks (List[Any], optional): List of callback functions.
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
-        self.nt_per_class = None
-        self.nt_per_image = None
         self.is_coco = False
         self.is_lvis = False
         self.class_map = None
         self.args.task = "detect"
-        self.metrics = DetMetrics(save_dir=self.save_dir)
         self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
         self.niou = self.iouv.numel()
+        self.metrics = DetMetrics()
 
     def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -99,18 +95,16 @@ class DetectionValidator(BaseValidator):
         self.names = model.names
         self.nc = len(model.names)
         self.end2end = getattr(model, "end2end", False)
-        self.metrics.names = self.names
-        self.metrics.plot = self.args.plots
-        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, names=self.names.values())
         self.seen = 0
         self.jdict = []
-        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+        self.metrics.names = self.names
+        self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))
 
     def get_desc(self) -> str:
         """Return a formatted string summarizing class metrics of YOLO model."""
         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
 
-    def postprocess(self, preds: torch.Tensor) -> List[torch.Tensor]:
+    def postprocess(self, preds: torch.Tensor) -> List[Dict[str, torch.Tensor]]:
         """
         Apply Non-maximum suppression to prediction outputs.
 
@@ -118,9 +112,10 @@ class DetectionValidator(BaseValidator):
             preds (torch.Tensor): Raw predictions from the model.
 
         Returns:
-            (List[torch.Tensor]): Processed predictions after NMS.
+            (List[Dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
+                'bboxes', 'conf', 'cls', and 'extra' tensors.
         """
-        return ops.non_max_suppression(
+        outputs = ops.non_max_suppression(
             preds,
             self.args.conf,
             self.args.iou,
@@ -131,6 +126,7 @@ class DetectionValidator(BaseValidator):
             end2end=self.end2end,
             rotated=self.args.task == "obb",
         )
+        return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
 
     def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -152,68 +148,60 @@ class DetectionValidator(BaseValidator):
         if len(cls):
             bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
-        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+        return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         """
         Prepare predictions for evaluation against ground truth.
 
         Args:
-            pred (torch.Tensor): Model predictions.
+            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
             pbatch (Dict[str, Any]): Prepared batch information.
 
         Returns:
-            (torch.Tensor): Prepared predictions in native space.
-        """
-        predn = pred.clone()
-        ops.scale_boxes(
-            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+            (Dict[str, torch.Tensor]): Prepared predictions in native space.
+        """
+        cls = pred["cls"]
+        if self.args.single_cls:
+            cls *= 0
+        # predn = pred.clone()
+        bboxes = ops.scale_boxes(
+            pbatch["imgsz"], pred["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
         )  # native-space pred
-        return predn
+        return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
 
-    def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]) -> None:
+    def update_metrics(self, preds: List[Dict[str, torch.Tensor]], batch: Dict[str, Any]) -> None:
         """
         Update metrics with new predictions and ground truth.
 
         Args:
-            preds (List[torch.Tensor]): List of predictions from the model.
+            preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
             batch (Dict[str, Any]): Batch data containing ground truth.
         """
         for si, pred in enumerate(preds):
             self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-            )
             pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
             predn = self._prepare_pred(pred, pbatch)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
 
+            cls = pbatch["cls"].cpu().numpy()
+            no_pred = len(predn["cls"]) == 0
+            if no_pred and len(cls) == 0:
+                continue
+            self.metrics.update_stats(
+                {
+                    **self._process_batch(predn, pbatch),
+                    "target_cls": cls,
+                    "target_img": np.unique(cls),
+                    "conf": np.zeros(0) if no_pred else predn["conf"].cpu().numpy(),
+                    "pred_cls": np.zeros(0) if no_pred else predn["cls"].cpu().numpy(),
+                }
+            )
             # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
             if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
+                self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)
+
+            if no_pred:
+                continue
 
             # Save
             if self.args.save_json:
@@ -241,44 +229,45 @@ class DetectionValidator(BaseValidator):
         Returns:
             (Dict[str, Any]): Dictionary containing metrics results.
         """
-        stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
-        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
-        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
-        stats.pop("target_img", None)
-        if len(stats):
-            self.metrics.process(**stats, on_plot=self.on_plot)
+        self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)
+        self.metrics.clear_stats()
         return self.metrics.results_dict
 
     def print_results(self) -> None:
         """Print training/validation set metrics per class."""
         pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
-        LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
-        if self.nt_per_class.sum() == 0:
+        LOGGER.info(pf % ("all", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))
+        if self.metrics.nt_per_class.sum() == 0:
             LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
 
         # Print results per class
-        if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
+        if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):
             for i, c in enumerate(self.metrics.ap_class_index):
                 LOGGER.info(
-                    pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
+                    pf
+                    % (
+                        self.names[c],
+                        self.metrics.nt_per_image[c],
+                        self.metrics.nt_per_class[c],
+                        *self.metrics.class_result(i),
+                    )
                 )
 
-    def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         Return correct prediction matrix.
 
         Args:
-            detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is
-                (x1, y1, x2, y2, conf, class).
-            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each
-                bounding box is of the format: (x1, y1, x2, y2).
-            gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices.
+            preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
+            batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
 
         Returns:
-            (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
+            (Dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
         """
-        iou = box_iou(gt_bboxes, detections[:, :4])
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
+        if len(batch["cls"]) == 0 or len(preds["cls"]) == 0:
+            return {"tp": np.zeros((len(preds["cls"]), self.niou), dtype=bool)}
+        iou = box_iou(batch["bboxes"], preds["bboxes"])
+        return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
 
     def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None) -> torch.utils.data.Dataset:
         """
@@ -317,42 +306,50 @@ class DetectionValidator(BaseValidator):
             ni (int): Batch index.
         """
         plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
+            labels=batch,
             paths=batch["im_file"],
             fname=self.save_dir / f"val_batch{ni}_labels.jpg",
             names=self.names,
             on_plot=self.on_plot,
         )
 
-    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
+    def plot_predictions(
+        self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int, max_det: Optional[int] = None
+    ) -> None:
         """
         Plot predicted bounding boxes on input images and save the result.
 
         Args:
             batch (Dict[str, Any]): Batch containing images and annotations.
-            preds (List[torch.Tensor]): List of predictions from the model.
+            preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
             ni (int): Batch index.
-        """
+            max_det (Optional[int]): Maximum number of detections to plot.
+        """
+        # TODO: optimize this
+        for i, pred in enumerate(preds):
+            pred["batch_idx"] = torch.ones_like(pred["conf"]) * i  # add batch index to predictions
+        keys = preds[0].keys()
+        max_det = max_det or self.args.max_det
+        batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}
+        # TODO: fix this
+        batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4])  # convert to xywh format
         plot_images(
-            batch["img"],
-            *output_to_target(preds, max_det=self.args.max_det),
+            images=batch["img"],
+            labels=batched_preds,
             paths=batch["im_file"],
             fname=self.save_dir / f"val_batch{ni}_pred.jpg",
             names=self.names,
             on_plot=self.on_plot,
         )  # pred
 
-    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
+    def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
-            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
             save_conf (bool): Whether to save confidence scores.
-            shape (Tuple[int, int]): Shape of the original image.
+            shape (Tuple[int, int]): Shape of the original image (height, width).
             file (Path): File path to save the detections.
         """
         from ultralytics.engine.results import Results
@@ -361,28 +358,29 @@ class DetectionValidator(BaseValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
         """
         Serialize YOLO predictions to COCO json format.
 
         Args:
-            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
+                with bounding box coordinates, confidence scores, and class predictions.
             filename (str): Image filename.
         """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        for p, b in zip(predn.tolist(), box.tolist()):
+        for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(c)],
                     "bbox": [round(x, 3) for x in b],
-                    "score": round(p[4], 5),
+                    "score": round(s, 5),
                 }
             )
 
Discard
@@ -3,12 +3,12 @@
 from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union
 
+import numpy as np
 import torch
 
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.metrics import OBBMetrics, batch_probiou
-from ultralytics.utils.plotting import output_to_rotated_target, plot_images
 
 
 class OBBValidator(DetectionValidator):
@@ -55,7 +55,7 @@ class OBBValidator(DetectionValidator):
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
         self.args.task = "obb"
-        self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
+        self.metrics = OBBMetrics()
 
     def init_metrics(self, model: torch.nn.Module) -> None:
         """
@@ -68,20 +68,20 @@ class OBBValidator(DetectionValidator):
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
 
-    def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
         """
         Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
 
         Args:
-            detections (torch.Tensor): Detected bounding boxes and associated data with shape (N, 7) where each
-                detection is represented as (x1, y1, x2, y2, conf, class, angle).
-            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (M, 5) where each box is represented
-                as (x1, y1, x2, y2, angle).
-            gt_cls (torch.Tensor): Class labels for the ground truth bounding boxes with shape (M,).
+            preds (Dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
+                class labels and bounding boxes.
+            batch (Dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
+                class labels and bounding boxes.
 
         Returns:
-            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU levels for each
-                detection, indicating the accuracy of predictions compared to the ground truth.
+            (Dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
+                array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
+                of predictions compared to the ground truth.
 
         Examples:
             >>> detections = torch.rand(100, 7)  # 100 sample detections
@@ -89,10 +89,25 @@ class OBBValidator(DetectionValidator):
             >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
             >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
         """
-        iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
+        if len(batch["cls"]) == 0 or len(preds["cls"]) == 0:
+            return {"tp": np.zeros((len(preds["cls"]), self.niou), dtype=bool)}
+        iou = batch_probiou(batch["bboxes"], preds["bboxes"])
+        return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
 
-    def _prepare_batch(self, si: int, batch: Dict) -> Dict:
+    def postprocess(self, preds: torch.Tensor) -> List[Dict[str, torch.Tensor]]:
+        """
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (List[Dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
+        """
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1)  # concatenate angle
+        return preds
+
+    def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
         Prepare batch data for OBB validation with proper scaling and formatting.
 
@@ -118,9 +133,9 @@ class OBBValidator(DetectionValidator):
         if len(cls):
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
-        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+        return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         """
         Prepare predictions by scaling bounding boxes to original image dimensions.
 
@@ -128,20 +143,22 @@ class OBBValidator(DetectionValidator):
         input dimensions to the original image dimensions using the provided batch information.
 
         Args:
-            pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
+            pred (Dict[str, torch.Tensor]): Prediction dictionary containing bounding box coordinates and other information.
             pbatch (Dict[str, Any]): Dictionary containing batch information with keys:
                 - imgsz (tuple): Model input image size.
                 - ori_shape (tuple): Original image shape.
                 - ratio_pad (tuple): Ratio and padding information for scaling.
 
         Returns:
-            (torch.Tensor): Scaled prediction tensor with bounding boxes in original image dimensions.
+            (Dict[str, torch.Tensor]): Scaled prediction dictionary with bounding boxes in original image dimensions.
         """
-        predn = pred.clone()
-        ops.scale_boxes(
-            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+        cls = pred["cls"]
+        if self.args.single_cls:
+            cls *= 0
+        bboxes = ops.scale_boxes(
+            pbatch["imgsz"], pred["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
         )  # native-space pred
-        return predn
+        return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
 
     def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
         """
@@ -158,22 +175,18 @@ class OBBValidator(DetectionValidator):
             >>> preds = [torch.rand(10, 7)]  # Example predictions for one image
             >>> validator.plot_predictions(batch, preds, 0)
         """
-        plot_images(
-            batch["img"],
-            *output_to_rotated_target(preds, max_det=self.args.max_det),
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
+        for p in preds:
+            # TODO: fix this duplicated `xywh2xyxy`
+            p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4])  # convert to xyxy format for plotting
+        super().plot_predictions(batch, preds, ni)  # plot bboxes
 
-    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]) -> None:
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: Union[str, Path]) -> None:
         """
         Convert YOLO predictions to COCO JSON format with rotated bounding box information.
 
         Args:
-            predn (torch.Tensor): Prediction tensor containing bounding box coordinates, confidence scores,
-                class predictions, and rotation angles with shape (N, 6+) where the last column is the angle.
+            predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
+                with bounding box coordinates, confidence scores, and class predictions.
             filename (str | Path): Path to the image file for which predictions are being processed.
 
         Notes:
@@ -183,22 +196,20 @@ class OBBValidator(DetectionValidator):
         """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+        rbox = predn["bboxes"]
         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
-        for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
+        for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(predn[i, 5].item())],
-                    "score": round(predn[i, 4].item(), 5),
+                    "category_id": self.class_map[int(c)],
+                    "score": round(s, 5),
                     "rbox": [round(x, 3) for x in r],
                     "poly": [round(x, 3) for x in b],
                 }
             )
 
-    def save_one_txt(
-        self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]
-    ) -> None:
+    def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         Save YOLO OBB detections to a text file in normalized coordinates.
 
@@ -207,7 +218,7 @@ class OBBValidator(DetectionValidator):
                 class predictions, and angles in format (x, y, w, h, conf, cls, angle).
             save_conf (bool): Whether to save confidence scores in the text file.
             shape (Tuple[int, int]): Original image shape in format (height, width).
-            file (Path | str): Output file path to save detections.
+            file (Path): Output file path to save detections.
 
         Examples:
             >>> validator = OBBValidator()
@@ -218,14 +229,11 @@ class OBBValidator(DetectionValidator):
 
         from ultralytics.engine.results import Results
 
-        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
-        # xywh, r, conf, cls
-        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
         Results(
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            obb=obb,
+            obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
         ).save_txt(file, save_conf=save_conf)
 
     def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
Discard
@@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import PoseModel
 from ultralytics.utils import DEFAULT_CFG, LOGGER
-from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.utils.plotting import plot_results
 
 
 class PoseTrainer(yolo.detect.DetectionTrainer):
@@ -108,40 +108,6 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def plot_training_samples(self, batch: Dict[str, Any], ni: int):
-        """
-        Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
-
-        Args:
-            batch (dict): Dictionary containing batch data with the following keys:
-                - img (torch.Tensor): Batch of images
-                - keypoints (torch.Tensor): Keypoints coordinates for pose estimation
-                - cls (torch.Tensor): Class labels
-                - bboxes (torch.Tensor): Bounding box coordinates
-                - im_file (list): List of image file paths
-                - batch_idx (torch.Tensor): Batch indices for each instance
-            ni (int): Current training iteration number used for filename
-
-        The function saves the plotted batch as an image in the trainer's save directory with the filename
-        'train_batch{ni}.jpg', where ni is the iteration number.
-        """
-        images = batch["img"]
-        kpts = batch["keypoints"]
-        cls = batch["cls"].squeeze(-1)
-        bboxes = batch["bboxes"]
-        paths = batch["im_file"]
-        batch_idx = batch["batch_idx"]
-        plot_images(
-            images,
-            batch_idx,
-            cls,
-            bboxes,
-            kpts=kpts,
-            paths=paths,
-            fname=self.save_dir / f"train_batch{ni}.jpg",
-            on_plot=self.on_plot,
-        )
-
     def plot_metrics(self):
         """Plot training/validation metrics."""
         plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
Discard
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Tuple
 
 import numpy as np
 import torch
@@ -9,8 +9,7 @@ import torch
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
 
 
 class PoseValidator(DetectionValidator):
@@ -33,7 +32,6 @@ class PoseValidator(DetectionValidator):
         _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
             dimensions.
         _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
-        update_metrics: Update metrics with new predictions and ground truth data.
         _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
             detections and ground truth.
         plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
@@ -77,7 +75,7 @@ class PoseValidator(DetectionValidator):
         self.sigma = None
         self.kpt_shape = None
         self.args.task = "pose"
-        self.metrics = PoseMetrics(save_dir=self.save_dir)
+        self.metrics = PoseMetrics()
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
             LOGGER.warning(
                 "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
@@ -118,7 +116,36 @@ class PoseValidator(DetectionValidator):
         is_pose = self.kpt_shape == [17, 3]
         nkpt = self.kpt_shape[0]
         self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
-        self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+
+    def postprocess(self, preds: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """
+        Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
+
+        This method extends the parent class postprocessing by extracting keypoints from the 'extra'
+        field of predictions and reshaping them according to the keypoint shape configuration.
+        The keypoints are reshaped from a flattened format to the proper dimensional structure
+        (typically [N, 17, 3] for COCO pose format).
+
+        Args:
+            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
+                bounding boxes, confidence scores, class predictions, and keypoint data.
+
+        Returns:
+            (List[Dict[str, torch.Tensor]]): List of processed prediction dictionaries, each containing:
+                - 'bboxes': Bounding box coordinates
+                - 'conf': Confidence scores
+                - 'cls': Class predictions
+                - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
+
+        Note:
+            The loop is applied to every prediction; none are skipped. The keypoints are popped
+            from the 'extra' field, which contains additional task-specific data beyond basic
+            detection, so 'extra' is no longer present in the returned prediction dictionaries.
+        """
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["keypoints"] = pred.pop("extra").reshape(-1, *self.kpt_shape)  # remove extra if exists
+        return preds
 
     def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -142,10 +169,10 @@ class PoseValidator(DetectionValidator):
         kpts[..., 0] *= w
         kpts[..., 1] *= h
         kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-        pbatch["kpts"] = kpts
+        pbatch["keypoints"] = kpts
         return pbatch
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _prepare_pred(self, pred: Dict[str, Any], pbatch: Dict[str, Any]) -> Dict[str, Any]:
         """
         Prepare and scale keypoints in predictions for pose processing.
 
@@ -154,189 +181,59 @@ class PoseValidator(DetectionValidator):
         to match the original image dimensions.
 
         Args:
-            pred (torch.Tensor): Raw prediction tensor from the model.
+            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
             pbatch (Dict[str, Any]): Processed batch dictionary containing image information including:
                 - imgsz: Image size used for inference
                 - ori_shape: Original image shape
                 - ratio_pad: Ratio and padding information for coordinate scaling
 
         Returns:
-            predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
-            pred_kpts (torch.Tensor): Predicted keypoints scaled to original image dimensions.
+            (Dict[str, Any]): Processed prediction dictionary with keypoints scaled to original image dimensions.
         """
         predn = super()._prepare_pred(pred, pbatch)
-        nk = pbatch["kpts"].shape[1]
-        pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
-        ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-        return predn, pred_kpts
-
-    def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]) -> None:
-        """
-        Update metrics with new predictions and ground truth data.
-
-        This method processes each prediction, compares it with ground truth, and updates various statistics
-        for performance evaluation.
+        predn["keypoints"] = ops.scale_coords(
+            pbatch["imgsz"], pred.get("keypoints").clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+        )
+        return predn
 
-        Args:
-            preds (List[torch.Tensor]): List of prediction tensors from the model.
-            batch (Dict[str, Any]): Batch data containing images and ground truth annotations.
-        """
-        for si, pred in enumerate(preds):
-            self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-            )
-            pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
-            predn, pred_kpts = self._prepare_pred(pred, pbatch)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
-
-            # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
-                stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
-            if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
-
-            # Save
-            if self.args.save_json:
-                self.pred_to_json(predn, batch["im_file"][si])
-            if self.args.save_txt:
-                self.save_one_txt(
-                    predn,
-                    pred_kpts,
-                    self.args.save_conf,
-                    pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
-                )
-
-    def _process_batch(
-        self,
-        detections: torch.Tensor,
-        gt_bboxes: torch.Tensor,
-        gt_cls: torch.Tensor,
-        pred_kpts: Optional[torch.Tensor] = None,
-        gt_kpts: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
 
         Args:
-            detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
-                detection is of the format (x1, y1, x2, y2, conf, class).
-            gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
-                box is of the format (x1, y1, x2, y2).
-            gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
-            pred_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing predicted keypoints, where
-                51 corresponds to 17 keypoints each having 3 values.
-            gt_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing ground truth keypoints.
+            preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
+                and 'keypoints' for keypoint predictions.
+            batch (Dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
+                'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
 
         Returns:
-            (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
-                where N is the number of detections.
+            (Dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
+                true positives across 10 IoU levels.
 
         Notes:
             `0.53` scale factor used in area computation is referenced from
             https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
         """
-        if pred_kpts is not None and gt_kpts is not None:
+        tp = super()._process_batch(preds, batch)
+        gt_cls = batch["cls"]
+        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
+            tp_p = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
+        else:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
-            area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
-            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
-        else:  # boxes
-            iou = box_iou(gt_bboxes, detections[:, :4])
+            area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
+            iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
+            tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_p": tp_p})  # update tp with kpts IoU
+        return tp
 
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
-
-    def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
-        """
-        Plot and save validation set samples with ground truth bounding boxes and keypoints.
-
-        Args:
-            batch (Dict[str, Any]): Dictionary containing batch data with keys:
-                - img (torch.Tensor): Batch of images
-                - batch_idx (torch.Tensor): Batch indices for each image
-                - cls (torch.Tensor): Class labels
-                - bboxes (torch.Tensor): Bounding box coordinates
-                - keypoints (torch.Tensor): Keypoint coordinates
-                - im_file (list): List of image file paths
-            ni (int): Batch index used for naming the output file
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            kpts=batch["keypoints"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )
-
-    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
-        """
-        Plot and save model predictions with bounding boxes and keypoints.
-
-        Args:
-            batch (Dict[str, Any]): Dictionary containing batch data including images, file paths, and other metadata.
-            preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding boxes,
-                confidence scores, class predictions, and keypoints.
-            ni (int): Batch index used for naming the output file.
-
-        The function extracts keypoints from predictions, converts predictions to target format, and plots them
-        on the input images. The resulting visualization is saved to the specified save directory.
-        """
-        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
-        plot_images(
-            batch["img"],
-            *output_to_target(preds, max_det=self.args.max_det),
-            kpts=pred_kpts,
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
-
-    def save_one_txt(
-        self,
-        predn: torch.Tensor,
-        pred_kpts: torch.Tensor,
-        save_conf: bool,
-        shape: Tuple[int, int],
-        file: Path,
-    ) -> None:
+    def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         Save YOLO pose detections to a text file in normalized coordinates.
 
         Args:
-            predn (torch.Tensor): Prediction boxes and scores with shape (N, 6) for (x1, y1, x2, y2, conf, cls).
-            pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D) where K is the number of keypoints
-                and D is the dimension (typically 3 for x, y, visibility).
+            predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints'.
             save_conf (bool): Whether to save confidence scores.
-            shape (tuple): Original image shape (height, width).
+            shape (Tuple[int, int]): Shape of the original image (height, width).
             file (Path): Output file path to save detections.
 
         Notes:
@@ -349,11 +246,11 @@ class PoseValidator(DetectionValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
-            keypoints=pred_kpts,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            keypoints=predn["keypoints"],
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
         """
         Convert YOLO predictions to COCO JSON format.
 
@@ -361,10 +258,9 @@ class PoseValidator(DetectionValidator):
         to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
 
         Args:
-            predn (torch.Tensor): Prediction tensor containing bounding boxes, confidence scores, class IDs,
-                and keypoints, with shape (N, 6+K) where N is the number of predictions and K is the flattened
-                keypoints dimension.
-            filename (str | Path): Path to the image file for which predictions are being processed.
+            predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
+                and 'keypoints' tensors.
+            filename (str): Path to the image file for which predictions are being processed.
 
         Notes:
             The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
@@ -373,16 +269,21 @@ class PoseValidator(DetectionValidator):
         """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        for p, b in zip(predn.tolist(), box.tolist()):
+        for b, s, c, k in zip(
+            box.tolist(),
+            predn["conf"].tolist(),
+            predn["cls"].tolist(),
+            predn["keypoints"].flatten(1, 2).tolist(),
+        ):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(c)],
                     "bbox": [round(x, 3) for x in b],
-                    "keypoints": p[6:],
-                    "score": round(p[4], 5),
+                    "keypoints": k,
+                    "score": round(s, 5),
                 }
             )
 
Discard
@@ -7,7 +7,7 @@ from typing import Dict, Optional, Union
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.utils import DEFAULT_CFG, RANK
-from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.utils.plotting import plot_results
 
 
 class SegmentationTrainer(yolo.detect.DetectionTrainer):
@@ -82,46 +82,6 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def plot_training_samples(self, batch: Dict, ni: int):
-        """
-        Plot training sample images with labels, bounding boxes, and masks.
-
-        This method creates a visualization of training batch images with their corresponding labels, bounding boxes,
-        and segmentation masks, saving the result to a file for inspection and debugging.
-
-        Args:
-            batch (dict): Dictionary containing batch data with the following keys:
-                'img': Images tensor
-                'batch_idx': Batch indices for each box
-                'cls': Class labels tensor (squeezed to remove last dimension)
-                'bboxes': Bounding box coordinates tensor
-                'masks': Segmentation masks tensor
-                'im_file': List of image file paths
-            ni (int): Current training iteration number, used for naming the output file.
-
-        Examples:
-            >>> trainer = SegmentationTrainer()
-            >>> batch = {
-            ...     "img": torch.rand(16, 3, 640, 640),
-            ...     "batch_idx": torch.zeros(16),
-            ...     "cls": torch.randint(0, 80, (16, 1)),
-            ...     "bboxes": torch.rand(16, 4),
-            ...     "masks": torch.rand(16, 640, 640),
-            ...     "im_file": ["image1.jpg", "image2.jpg"],
-            ... }
-            >>> trainer.plot_training_samples(batch, ni=5)
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            masks=batch["masks"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"train_batch{ni}.jpg",
-            on_plot=self.on_plot,
-        )
-
     def plot_metrics(self):
         """Plot training/validation metrics."""
         plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
Discard
@@ -2,7 +2,7 @@
 
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -11,8 +11,7 @@ import torch.nn.functional as F
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, NUM_THREADS, ops
 from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.metrics import SegmentMetrics, mask_iou
 
 
 class SegmentationValidator(DetectionValidator):
@@ -47,10 +46,9 @@ class SegmentationValidator(DetectionValidator):
             _callbacks (list, optional): List of callback functions.
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
-        self.plot_masks = None
         self.process = None
         self.args.task = "segment"
-        self.metrics = SegmentMetrics(save_dir=self.save_dir)
+        self.metrics = SegmentMetrics()
 
     def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -74,12 +72,10 @@ class SegmentationValidator(DetectionValidator):
             model (torch.nn.Module): Model to validate.
         """
         super().init_metrics(model)
-        self.plot_masks = []
         if self.args.save_json:
             check_requirements("pycocotools>=2.0.6")
         # More accurate vs faster
         self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
-        self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
     def get_desc(self) -> str:
         """Return a formatted description of evaluation metrics."""
@@ -97,7 +93,7 @@ class SegmentationValidator(DetectionValidator):
             "mAP50-95)",
         )
 
-    def postprocess(self, preds: List[torch.Tensor]) -> Tuple[List[torch.Tensor], torch.Tensor]:
+    def postprocess(self, preds: List[torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
         """
         Post-process YOLO predictions and return output detections with masks.
 
@@ -105,12 +101,19 @@ class SegmentationValidator(DetectionValidator):
             preds (List[torch.Tensor]): Raw predictions from the model.
 
         Returns:
-            p (List[torch.Tensor]): Processed detection predictions.
-            proto (torch.Tensor): Prototype masks for segmentation.
+            List[Dict[str, torch.Tensor]]: Processed detection predictions with masks.
         """
-        p = super().postprocess(preds[0])
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
-        return p, proto
+        preds = super().postprocess(preds[0])
+        imgsz = [4 * x for x in proto.shape[2:]]  # get image size from proto
+        for i, pred in enumerate(preds):
+            coefficient = pred.pop("extra")
+            pred["masks"] = (
+                self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
+                if len(coefficient)
+                else torch.zeros((0, imgsz[0], imgsz[1]), dtype=torch.uint8, device=pred["bboxes"].device)
+            )
+        return preds
 
     def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -128,142 +131,56 @@ class SegmentationValidator(DetectionValidator):
         prepared_batch["masks"] = batch["masks"][midx]
         return prepared_batch
 
-    def _prepare_pred(
-        self, pred: torch.Tensor, pbatch: Dict[str, Any], proto: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         """
         Prepare predictions for evaluation by processing bounding boxes and masks.
 
         Args:
-            pred (torch.Tensor): Raw predictions from the model.
+            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
             pbatch (Dict[str, Any]): Prepared batch information.
-            proto (torch.Tensor): Prototype masks for segmentation.
 
         Returns:
-            predn (torch.Tensor): Processed bounding box predictions.
-            pred_masks (torch.Tensor): Processed mask predictions.
+            Dict[str, torch.Tensor]: Processed bounding box predictions together with their 'masks'.
         """
         predn = super()._prepare_pred(pred, pbatch)
-        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
-        return predn, pred_masks
-
-    def update_metrics(self, preds: Tuple[List[torch.Tensor], torch.Tensor], batch: Dict[str, Any]) -> None:
-        """
-        Update metrics with the current batch predictions and targets.
-
-        Args:
-            preds (Tuple[List[torch.Tensor], torch.Tensor]): List of predictions from the model.
-            batch (Dict[str, Any]): Batch data containing ground truth.
-        """
-        for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
-            self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+        predn["masks"] = pred["masks"]
+        if self.args.save_json and len(predn["masks"]):
+            coco_masks = torch.as_tensor(pred["masks"], dtype=torch.uint8)
+            coco_masks = ops.scale_image(
+                coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+                pbatch["ori_shape"],
+                ratio_pad=pbatch["ratio_pad"],
             )
-            pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Masks
-            gt_masks = pbatch.pop("masks")
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
-            predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
-
-            # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
-                stat["tp_m"] = self._process_batch(
-                    predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
-                )
-            if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
-
-            pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
-            if self.args.plots and self.batch_i < 3:
-                self.plot_masks.append(pred_masks[:50].cpu())  # Limit plotted items for speed
-                if pred_masks.shape[0] > 50:
-                    LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
-
-            # Save
-            if self.args.save_json:
-                self.pred_to_json(
-                    predn,
-                    batch["im_file"][si],
-                    ops.scale_image(
-                        pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
-                        pbatch["ori_shape"],
-                        ratio_pad=batch["ratio_pad"][si],
-                    ),
-                )
-            if self.args.save_txt:
-                self.save_one_txt(
-                    predn,
-                    pred_masks,
-                    self.args.save_conf,
-                    pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
-                )
-
-    def _process_batch(
-        self,
-        detections: torch.Tensor,
-        gt_bboxes: torch.Tensor,
-        gt_cls: torch.Tensor,
-        pred_masks: Optional[torch.Tensor] = None,
-        gt_masks: Optional[torch.Tensor] = None,
-        overlap: Optional[bool] = False,
-        masks: Optional[bool] = False,
-    ) -> torch.Tensor:
+            predn["coco_masks"] = coco_masks
+        return predn
+
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
 
         Args:
-            detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and
-                associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class].
-            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
-                Each row is of the format [x1, y1, x2, y2].
-            gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
-            pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
-                match the ground truth masks.
-            gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
-            overlap (bool, optional): Flag indicating if overlapping masks should be considered.
-            masks (bool, optional): Flag indicating if the batch contains mask data.
+            preds (Dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
+            batch (Dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.
 
         Returns:
-            (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
+            (Dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
 
         Notes:
             - Mask IoU between predicted and ground truth masks is computed when both predictions and targets are non-empty.
             - If `self.args.overlap_mask` is True, overlapping masks are taken into account when computing IoU.
 
         Examples:
-            >>> detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
-            >>> gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
-            >>> gt_cls = torch.tensor([1, 0])
-            >>> correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
+            >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
+            >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
+            >>> correct_preds = validator._process_batch(preds, batch)
         """
-        if masks:
-            if overlap:
+        tp = super()._process_batch(preds, batch)
+        gt_cls, gt_masks = batch["cls"], batch["masks"]
+        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
+            tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
+        else:
+            pred_masks = preds["masks"]
+            if self.args.overlap_mask:
                 nl = len(gt_cls)
                 index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
                 gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
@@ -272,60 +189,32 @@ class SegmentationValidator(DetectionValidator):
                 gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
                 gt_masks = gt_masks.gt_(0.5)
             iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
-        else:  # boxes
-            iou = box_iou(gt_bboxes, detections[:, :4])
-
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
+            tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_m": tp_m})  # update tp with mask IoU
+        return tp
 
-    def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
-        """
-        Plot validation samples with bounding box labels and masks.
-
-        Args:
-            batch (Dict[str, Any]): Batch containing images and annotations.
-            ni (int): Batch index.
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            masks=batch["masks"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )
-
-    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
+    def plot_predictions(self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int) -> None:
         """
         Plot batch predictions with masks and bounding boxes.
 
         Args:
             batch (Dict[str, Any]): Batch containing images and annotations.
-            preds (List[torch.Tensor]): List of predictions from the model.
+            preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
             ni (int): Batch index.
         """
-        plot_images(
-            batch["img"],
-            *output_to_target(preds[0], max_det=50),  # not set to self.args.max_det due to slow plotting speed
-            torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
-        self.plot_masks.clear()
-
-    def save_one_txt(
-        self, predn: torch.Tensor, pred_masks: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path
-    ) -> None:
+        for p in preds:
+            masks = p["masks"]
+            if masks.shape[0] > 50:
+                LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
+            p["masks"] = torch.as_tensor(masks[:50], dtype=torch.uint8).cpu()
+        super().plot_predictions(batch, preds, ni, max_det=50)  # plot bboxes
+
+    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
             predn (Dict[str, torch.Tensor]): Predictions with 'bboxes', 'conf', 'cls' and 'masks' keys.
-            pred_masks (torch.Tensor): Predicted masks.
             save_conf (bool): Whether to save confidence scores.
             shape (Tuple[int, int]): Shape of the original image.
             file (Path): File path to save the detections.
@@ -336,18 +225,17 @@ class SegmentationValidator(DetectionValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
-            masks=pred_masks,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn: torch.Tensor, filename: str, pred_masks: torch.Tensor) -> None:
+    def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
         """
         Save one JSON result for COCO evaluation.
 
         Args:
-            predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
+            predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
             filename (str): Image filename.
-            pred_masks (numpy.ndarray): Predicted masks.
 
         Examples:
              >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
@@ -362,18 +250,18 @@ class SegmentationValidator(DetectionValidator):
 
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        pred_masks = np.transpose(pred_masks, (2, 0, 1))
+        pred_masks = np.transpose(predn["coco_masks"], (2, 0, 1))
         with ThreadPool(NUM_THREADS) as pool:
             rles = pool.map(single_encode, pred_masks)
-        for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
+        for i, (b, s, c) in enumerate(zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist())):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(c)],
                     "bbox": [round(x, 3) for x in b],
-                    "score": round(p[4], 5),
+                    "score": round(s, 5),
                     "segmentation": rles[i],
                 }
             )
Discard
@@ -457,7 +457,7 @@ def _log_plots(experiment, trainer) -> None:
         >>> _log_plots(experiment, trainer)
     """
     plot_filenames = None
-    if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment":
+    if isinstance(trainer.validator.metrics, SegmentMetrics):
         plot_filenames = [
             trainer.save_dir / f"{prefix}{plots}.png"
             for plots in EVALUATION_PLOT_NAMES
Discard
@@ -4,7 +4,7 @@
 import math
 import warnings
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import numpy as np
 import torch
@@ -316,28 +316,22 @@ class ConfusionMatrix(DataExportMixin):
     Attributes:
         task (str): The type of task, either 'detect' or 'classify'.
         matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
-        nc (int): The number of classes.
-        conf (float): The confidence threshold for detections.
-        iou_thres (float): The Intersection over Union threshold.
+        nc (int): The number of classes.
+        names (List[str]): The names of the classes, used as labels on the plot.
     """
 
-    def __init__(self, nc: int, conf: float = 0.25, iou_thres: float = 0.45, names: tuple = (), task: str = "detect"):
+    def __init__(self, names: List[str] = [], task: str = "detect"):
         """
         Initialize a ConfusionMatrix instance.
 
         Args:
-            nc (int): Number of classes.
-            conf (float, optional): Confidence threshold for detections.
-            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
-            names (tuple, optional): Names of classes, used as labels on the plot.
+            names (List[str], optional): Names of classes, used as labels on the plot.
             task (str, optional): Type of task, either 'detect' or 'classify'.
         """
         self.task = task
-        self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
-        self.nc = nc  # number of classes
-        self.names = list(names)  # name of classes
-        self.conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
-        self.iou_thres = iou_thres
+        self.nc = len(names)  # number of classes
+        self.matrix = np.zeros((self.nc + 1, self.nc + 1)) if self.task == "detect" else np.zeros((self.nc, self.nc))
+        self.names = names  # name of classes
 
     def process_cls_preds(self, preds, targets):
         """
@@ -351,41 +345,45 @@ class ConfusionMatrix(DataExportMixin):
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
             self.matrix[p][t] += 1
 
-    def process_batch(self, detections, gt_bboxes, gt_cls):
+    def process_batch(
+        self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45
+    ) -> None:
         """
         Update confusion matrix for object detection task.
 
         Args:
-            detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
-                                      Each row should contain (x1, y1, x2, y2, conf, class)
-                                      or with an additional element `angle` when it's obb.
-            gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
-            gt_cls (Array[M]): The class labels.
+            detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.
+                                       Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
+                                       Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
+            batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and
+                'cls' (Array[M]) keys, where M is the number of ground truth objects.
+            conf (float, optional): Confidence threshold for detections.
+            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
         """
+        conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
+        gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
+        no_pred = len(detections["cls"]) == 0
         if gt_cls.shape[0] == 0:  # Check if labels is empty
-            if detections is not None:
-                detections = detections[detections[:, 4] > self.conf]
-                detection_classes = detections[:, 5].int().tolist()
+            if not no_pred:
+                detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
+                detection_classes = detections["cls"].int().tolist()
                 for dc in detection_classes:
                     self.matrix[dc, self.nc] += 1  # false positives
             return
-        if detections is None:
+        if no_pred:
             gt_classes = gt_cls.int().tolist()
             for gc in gt_classes:
                 self.matrix[self.nc, gc] += 1  # background FN
             return
 
-        detections = detections[detections[:, 4] > self.conf]
+        detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
         gt_classes = gt_cls.int().tolist()
-        detection_classes = detections[:, 5].int().tolist()
-        is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
-        iou = (
-            batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
-            if is_obb
-            else box_iou(gt_bboxes, detections[:, :4])
-        )
-
-        x = torch.where(iou > self.iou_thres)
+        detection_classes = detections["cls"].int().tolist()
+        bboxes = detections["bboxes"]
+        is_obb = bboxes.shape[1] == 5  # check if detections contains angle for OBB
+        iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)
+
+        x = torch.where(iou > iou_thres)
         if x[0].shape[0]:
             matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
             if x[0].shape[0] > 1:
@@ -949,53 +947,76 @@ class DetMetrics(SimpleClass, DataExportMixin):
     Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
 
     Attributes:
-        save_dir (Path): A path to the directory where the output plots will be saved.
-        plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
         names (Dict[int, str]): A dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'detect'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class: Number of targets per class.
+        nt_per_image: Number of targets per image.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
         Initialize a DetMetrics instance with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
         self.names = names
         self.box = Metric()
         self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "detect"
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+        self.nt_per_class = None
+        self.nt_per_image = None
 
-    def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
+    def update_stats(self, stat: Dict[str, Any]) -> None:
+        """
+        Update statistics by appending new values to existing stat collections.
+
+        Args:
+            stat (Dict[str, Any]): Dictionary containing new statistical values to append.
+                         Keys should match existing keys in self.stats.
+        """
+        for k in self.stats.keys():
+            self.stats[k].append(stat[k])
+
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process predicted results for object detection and update metrics.
 
         Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()}  # to numpy
+        if len(stats) == 0:
+            return stats
         results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
+            stats["tp"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
+            save_dir=save_dir,
             names=self.names,
             on_plot=on_plot,
         )[2:]
         self.box.nc = len(self.names)
         self.box.update(results)
+        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
+        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
+        return stats
+
+    def clear_stats(self):
+        """Clear the stored statistics."""
+        for v in self.stats.values():
+            v.clear()
 
     @property
     def keys(self) -> List[str]:
@@ -1077,92 +1098,65 @@ class DetMetrics(SimpleClass, DataExportMixin):
         ]
 
 
-class SegmentMetrics(SimpleClass, DataExportMixin):
+class SegmentMetrics(DetMetrics):
     """
     Calculate and aggregate detection and segmentation metrics over a given set of classes.
 
     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and segmentation plots.
         names (Dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'segment'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
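+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Per-class counts of the `target_img` statistic (presumably the number of images containing each class).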
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
+        Initialize a SegmentMetrics instance with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        DetMetrics.__init__(self, names)
         self.seg = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "segment"
+        self.stats["tp_m"] = []  # add additional stats for masks
 
-    def process(
-        self,
-        tp: np.ndarray,
-        tp_m: np.ndarray,
-        conf: np.ndarray,
-        pred_cls: np.ndarray,
-        target_cls: np.ndarray,
-        on_plot=None,
-    ):
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process the detection and segmentation metrics over the given set of predictions.
 
         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_m (np.ndarray): True positive array for masks.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
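+        stats = DetMetrics.process(self, on_plot=on_plot)  # process box stats; NOTE(review): save_dir/plot are not forwarded, so the box pass uses the DetMetrics.process defaults (plot=False)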
         results_mask = ap_per_class(
-            tp_m,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_m"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Mask",
         )[2:]
         self.seg.nc = len(self.names)
         self.seg.update(results_mask)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats
 
     @property
     def keys(self) -> List[str]:
         """Return a list of keys for accessing metrics."""
-        return [
-            "metrics/precision(B)",
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+        return DetMetrics.keys.fget(self) + [
             "metrics/precision(M)",
             "metrics/recall(M)",
             "metrics/mAP50(M)",
@@ -1171,40 +1165,26 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
 
     def mean_results(self) -> List[float]:
         """Return the mean metrics for bounding box and segmentation results."""
-        return self.box.mean_results() + self.seg.mean_results()
+        return DetMetrics.mean_results(self) + self.seg.mean_results()
 
     def class_result(self, i: int) -> List[float]:
         """Return classification results for a specified class index."""
-        return self.box.class_result(i) + self.seg.class_result(i)
+        return DetMetrics.class_result(self, i) + self.seg.class_result(i)
 
     @property
     def maps(self) -> np.ndarray:
         """Return mAP scores for object detection and semantic segmentation models."""
-        return self.box.maps + self.seg.maps
+        return DetMetrics.maps.fget(self) + self.seg.maps
 
     @property
     def fitness(self) -> float:
         """Return the fitness score for both segmentation and bounding box models."""
-        return self.seg.fitness() + self.box.fitness()
-
-    @property
-    def ap_class_index(self) -> List:
-        """Return the class indices (boxes and masks have the same ap_class_index)."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self) -> Dict[str, float]:
-        """Return results of object detection model for evaluation."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
+        return self.seg.fitness() + DetMetrics.fitness.fget(self)
 
     @property
     def curves(self) -> List[str]:
         """Return a list of curves for accessing specific metrics curves."""
-        return [
-            "Precision-Recall(B)",
-            "F1-Confidence(B)",
-            "Precision-Confidence(B)",
-            "Recall-Confidence(B)",
+        return DetMetrics.curves.fget(self) + [
             "Precision-Recall(M)",
             "F1-Confidence(M)",
             "Precision-Confidence(M)",
@@ -1214,7 +1194,7 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
     @property
     def curves_results(self) -> List[List]:
         """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.seg.curves_results
+        return DetMetrics.curves_results.fget(self) + self.seg.curves_results
 
     def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
         """
@@ -1234,43 +1214,34 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
             >>> print(seg_summary)
         """
         scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
             "mask-map": round(self.seg.map, decimals),
             "mask-map50": round(self.seg.map50, decimals),
             "mask-map75": round(self.seg.map75, decimals),
         }
         per_class = {
-            "box-p": self.box.p,
-            "box-r": self.box.r,
-            "box-f1": self.box.f1,
             "mask-p": self.seg.p,
             "mask-r": self.seg.r,
             "mask-f1": self.seg.f1,
         }
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
+        return summary
 
 
-class PoseMetrics(SegmentMetrics):
+class PoseMetrics(DetMetrics):
     """
     Calculate and aggregate detection and pose metrics over a given set of classes.
 
     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and pose plots.
         names (Dict[int, str]): Dictionary of class names.
         pose (Metric): An instance of the Metric class to calculate pose metrics.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'pose'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
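+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Per-class counts of the `target_img` statistic (presumably the number of images containing each class).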
 
     Methods:
-        process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.
+        process(save_dir, plot, on_plot): Process metrics over the given set of predictions.
@@ -1282,79 +1253,50 @@ class PoseMetrics(SegmentMetrics):
         results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize the PoseMetrics class with directory path, class names, and plotting options.
+        Initialize the PoseMetrics class with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        super().__init__(save_dir, plot, names)
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        super().__init__(names)
         self.pose = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "pose"
+        self.stats["tp_p"] = []  # add additional stats for pose
 
-    def process(
-        self,
-        tp: np.ndarray,
-        tp_p: np.ndarray,
-        conf: np.ndarray,
-        pred_cls: np.ndarray,
-        target_cls: np.ndarray,
-        on_plot=None,
-    ):
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process the detection and pose metrics over the given set of predictions.
 
         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_p (np.ndarray): True positive array for keypoints.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
             on_plot (callable, optional): Function to call after plots are generated.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
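+        stats = DetMetrics.process(self, on_plot=on_plot)  # process box stats; NOTE(review): save_dir/plot are not forwarded, so the box pass uses the DetMetrics.process defaults (plot=False)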
         results_pose = ap_per_class(
-            tp_p,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_p"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Pose",
         )[2:]
         self.pose.nc = len(self.names)
         self.pose.update(results_pose)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats
 
     @property
     def keys(self) -> List[str]:
         """Return list of evaluation metric keys."""
-        return [
-            "metrics/precision(B)",
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+        return DetMetrics.keys.fget(self) + [
             "metrics/precision(P)",
             "metrics/recall(P)",
             "metrics/mAP50(P)",
@@ -1363,26 +1305,26 @@ class PoseMetrics(SegmentMetrics):
 
     def mean_results(self) -> List[float]:
         """Return the mean results of box and pose."""
-        return self.box.mean_results() + self.pose.mean_results()
+        return DetMetrics.mean_results(self) + self.pose.mean_results()
 
     def class_result(self, i: int) -> List[float]:
         """Return the class-wise detection results for a specific class i."""
-        return self.box.class_result(i) + self.pose.class_result(i)
+        return DetMetrics.class_result(self, i) + self.pose.class_result(i)
 
     @property
     def maps(self) -> np.ndarray:
         """Return the mean average precision (mAP) per class for both box and pose detections."""
-        return self.box.maps + self.pose.maps
+        return DetMetrics.maps.fget(self) + self.pose.maps
 
     @property
     def fitness(self) -> float:
         """Return combined fitness score for pose and box detection."""
-        return self.pose.fitness() + self.box.fitness()
+        return self.pose.fitness() + DetMetrics.fitness.fget(self)
 
     @property
     def curves(self) -> List[str]:
         """Return a list of curves for accessing specific metrics curves."""
-        return [
+        return DetMetrics.curves.fget(self) + [  # NOTE(review): the "(B)" entries listed below look like duplicates of DetMetrics.curves - confirm
             "Precision-Recall(B)",
             "F1-Confidence(B)",
             "Precision-Confidence(B)",
@@ -1396,7 +1338,7 @@ class PoseMetrics(SegmentMetrics):
     @property
     def curves_results(self) -> List[List]:
         """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.pose.curves_results
+        return DetMetrics.curves_results.fget(self) + self.pose.curves_results
 
     def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
         """
@@ -1416,29 +1358,19 @@ class PoseMetrics(SegmentMetrics):
             >>> print(pose_summary)
         """
         scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
             "pose-map": round(self.pose.map, decimals),
             "pose-map50": round(self.pose.map50, decimals),
             "pose-map75": round(self.pose.map75, decimals),
         }
         per_class = {
-            "box-p": self.box.p,
-            "box-r": self.box.r,
-            "box-f1": self.box.f1,
             "pose-p": self.pose.p,
             "pose-r": self.pose.r,
             "pose-f1": self.pose.f1,
         }
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
+        return summary
 
 
 class ClassifyMetrics(SimpleClass, DataExportMixin):
@@ -1516,133 +1448,30 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
         return [{"classify-top1": round(self.top1, decimals), "classify-top5": round(self.top5, decimals)}]
 
 
-class OBBMetrics(SimpleClass, DataExportMixin):
+class OBBMetrics(DetMetrics):
     """
     Metrics for evaluating oriented bounding box (OBB) detection.
 
     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection plots.
         names (Dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'obb'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
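+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Per-class counts of the `target_img` statistic (presumably the number of images containing each class).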
 
     References:
         https://arxiv.org/pdf/2106.06072.pdf
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize an OBBMetrics instance with directory, plotting, and class names.
+        Initialize an OBBMetrics instance with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+        DetMetrics.__init__(self, names)
+        # TODO: probably remove task as well
         self.task = "obb"
-
-    def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
-        """
-        Process predicted results for object detection and update metrics.
-
-        Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
-        """
-        results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            on_plot=on_plot,
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results)
-
-    @property
-    def keys(self) -> List[str]:
-        """Return a list of keys for accessing specific metrics."""
-        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
-
-    def mean_results(self) -> List[float]:
-        """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
-        return self.box.mean_results()
-
-    def class_result(self, i: int) -> Tuple[float, float, float, float]:
-        """Return the result of evaluating the performance of an object detection model on a specific class."""
-        return self.box.class_result(i)
-
-    @property
-    def maps(self) -> np.ndarray:
-        """Return mean Average Precision (mAP) scores per class."""
-        return self.box.maps
-
-    @property
-    def fitness(self) -> float:
-        """Return the fitness of box object."""
-        return self.box.fitness()
-
-    @property
-    def ap_class_index(self) -> List:
-        """Return the average precision index per class."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self) -> Dict[str, float]:
-        """Return dictionary of computed performance metrics and statistics."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
-
-    @property
-    def curves(self) -> List:
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    @property
-    def curves_results(self) -> List:
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
-        """
-        Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
-        scalar metrics (mAP, mAP50, mAP75) along with precision, recall, and F1-score for each class.
-
-        Args:
-            normalize (bool): For OBB metrics, everything is normalized  by default [0-1].
-            decimals (int): Number of decimal places to round the metrics values to.
-
-        Returns:
-            (List[Dict[str, Union[str, float]]]): A list of dictionaries, each representing one class with detection metrics.
-
-        Examples:
-            >>> results = model.val(data="dota8.yaml")
-            >>> detection_summary = results.summary(decimals=4)
-            >>> print(detection_summary)
-        """
-        scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
-        }
-        per_class = {"box-p": self.box.p, "box-r": self.box.r, "box-f1": self.box.f1}
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
Discard
@@ -255,7 +255,7 @@ def non_max_suppression(
 
     bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
     nc = nc or (prediction.shape[1] - 4)  # number of classes
-    nm = prediction.shape[1] - nc - 4  # number of masks
+    extra = prediction.shape[1] - nc - 4  # number of extra channels beyond box and class scores (e.g. mask coefficients)
     mi = 4 + nc  # mask start index
     xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
     xinds = torch.stack([torch.arange(len(i), device=prediction.device) for i in xc])[..., None]  # to track idxs
@@ -273,7 +273,7 @@ def non_max_suppression(
             prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy
 
     t = time.time()
-    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
+    output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs
     keepi = [torch.zeros((0, 1), device=prediction.device)] * bs  # to store the kept idxs
     for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
         # Apply constraints
@@ -284,7 +284,7 @@ def non_max_suppression(
         # Cat apriori labels if autolabelling
         if labels and len(labels[xi]) and not rotated:
             lb = labels[xi]
-            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
+            v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
             v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
             v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
             x = torch.cat((x, v), 0)
@@ -294,7 +294,7 @@ def non_max_suppression(
             continue
 
         # Detections matrix nx6 (xyxy, conf, cls)
-        box, cls, mask = x.split((4, nc, nm), 1)
+        box, cls, mask = x.split((4, nc, extra), 1)
 
         if multi_label:
             i, j = torch.where(cls > conf_thres)
Discard
@@ -3,7 +3,7 @@
 import math
 import warnings
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import cv2
 import numpy as np
@@ -678,13 +678,8 @@ def save_one_box(
 
 @threaded
 def plot_images(
-    images: Union[torch.Tensor, np.ndarray],
-    batch_idx: Union[torch.Tensor, np.ndarray],
-    cls: Union[torch.Tensor, np.ndarray],
-    bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
-    confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
-    masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
-    kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
+    labels: Dict[str, Any],
+    images: Union[torch.Tensor, np.ndarray] = np.zeros((0, 3, 640, 640), dtype=np.float32),
     paths: Optional[List[str]] = None,
     fname: str = "images.jpg",
     names: Optional[Dict[int, str]] = None,
@@ -698,21 +693,16 @@ def plot_images(
     Plot image grid with labels, bounding boxes, masks, and keypoints.
 
     Args:
-        images: Batch of images to plot. Shape: (batch_size, channels, height, width).
-        batch_idx: Batch indices for each detection. Shape: (num_detections,).
-        cls: Class labels for each detection. Shape: (num_detections,).
-        bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
-        confs: Confidence scores for each detection. Shape: (num_detections,).
-        masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
-        kpts: Keypoints for each detection. Shape: (num_detections, 51).
-        paths: List of file paths for each image in the batch.
-        fname: Output filename for the plotted image grid.
-        names: Dictionary mapping class indices to class names.
-        on_plot: Optional callback function to be called after saving the plot.
-        max_size: Maximum size of the output image grid.
-        max_subplots: Maximum number of subplots in the image grid.
-        save: Whether to save the plotted image grid to a file.
-        conf_thres: Confidence threshold for displaying detections.
+        labels (Dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.
+        images (Union[torch.Tensor, np.ndarray]): Batch of images to plot, used only when `labels` has no 'img' entry. Shape: (batch_size, channels, height, width).
+        paths (Optional[List[str]]): List of file paths for each image in the batch.
+        fname (str): Output filename for the plotted image grid.
+        names (Optional[Dict[int, str]]): Dictionary mapping class indices to class names.
+        on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.
+        max_size (int): Maximum size of the output image grid.
+        max_subplots (int): Maximum number of subplots in the image grid.
+        save (bool): Whether to save the plotted image grid to a file.
+        conf_thres (float): Confidence threshold for displaying detections.
 
     Returns:
         (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
@@ -721,18 +711,24 @@ def plot_images(
         This function supports both tensor and numpy array inputs. It will automatically
         convert tensor inputs to numpy arrays for processing.
     """
-    if isinstance(images, torch.Tensor):
+    for k in {"cls", "bboxes", "conf", "masks", "keypoints", "batch_idx", "images"}:
+        if k not in labels:
+            continue
+        if k == "cls" and labels[k].ndim == 2:
+            labels[k] = labels[k].squeeze(1)  # squeeze if shape is (n, 1)
+        if isinstance(labels[k], torch.Tensor):
+            labels[k] = labels[k].cpu().numpy()
+
+    cls = labels.get("cls", np.zeros(0, dtype=np.int64))
+    batch_idx = labels.get("batch_idx", np.zeros(cls.shape, dtype=np.int64))
+    bboxes = labels.get("bboxes", np.zeros(0, dtype=np.float32))
+    confs = labels.get("conf", None)
+    masks = labels.get("masks", np.zeros(0, dtype=np.uint8))
+    kpts = labels.get("keypoints", np.zeros(0, dtype=np.float32))
+    images = labels.get("img", images)  # default to input images
+
+    if len(images) and isinstance(images, torch.Tensor):
         images = images.cpu().float().numpy()
-    if isinstance(cls, torch.Tensor):
-        cls = cls.cpu().numpy()
-    if isinstance(bboxes, torch.Tensor):
-        bboxes = bboxes.cpu().numpy()
-    if isinstance(masks, torch.Tensor):
-        masks = masks.cpu().numpy().astype(int)
-    if isinstance(kpts, torch.Tensor):
-        kpts = kpts.cpu().numpy()
-    if isinstance(batch_idx, torch.Tensor):
-        batch_idx = batch_idx.cpu().numpy()
     if images.shape[1] > 3:
         images = images[:, :3]  # crop multispectral images to first 3 channels
 
@@ -781,6 +777,7 @@ def plot_images(
                 boxes[..., 0] += x
                 boxes[..., 1] += y
                 is_obb = boxes.shape[-1] == 5  # xywhr
+                # TODO: this transformation might be unnecessary
                 boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
                 for j, box in enumerate(boxes.astype(np.int64).tolist()):
                     c = classes[j]
@@ -1004,28 +1001,6 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
     _save_one_file(csv_file.with_name("tune_fitness.png"))
 
 
-def output_to_target(output, max_det: int = 300):
-    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
-    targets = []
-    for i, o in enumerate(output):
-        box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
-        j = torch.full((conf.shape[0], 1), i)
-        targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
-    targets = torch.cat(targets, 0).numpy()
-    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
-
-
-def output_to_rotated_target(output, max_det: int = 300):
-    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
-    targets = []
-    for i, o in enumerate(output):
-        box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
-        j = torch.full((conf.shape[0], 1), i)
-        targets.append(torch.cat((j, cls, box, angle, conf), 1))
-    targets = torch.cat(targets, 0).numpy()
-    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
-
-
 def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
     """
     Visualize feature maps of a given model module during inference.
Discard