#21009 `ultralytics 8.3.154` Refactor `Validator` and `Metrics` classes

Merged
Ghost merged 1 commit into Ultralytics:main from ultralytics:validator-cleanup
@@ -143,8 +143,11 @@ To train a YOLO11 model using JupyterLab:
 5. Visualize training results using JupyterLab's built-in plotting capabilities:
 
     ```python
-    %matplotlib inline
+    import matplotlib
+
     from ultralytics.utils.plotting import plot_results
+
+    matplotlib.use("inline")  # or 'notebook' for interactive
     plot_results(results)
     ```
 
@@ -325,7 +325,7 @@ To use YOLOv7 ONNX model with Ultralytics:
 
 2. Install the `TensorRT` Python package:
 
-    ```python
+    ```bash
     pip install tensorrt
     ```
 
@@ -43,14 +43,6 @@ keywords: ultralytics, plotting, utilities, documentation, data visualization, a
 
 <br><br><hr><br>
 
-## ::: ultralytics.utils.plotting.output_to_target
-
-<br><br><hr><br>
-
-## ::: ultralytics.utils.plotting.output_to_rotated_target
-
-<br><br><hr><br>
-
 ## ::: ultralytics.utils.plotting.feature_visualization
 
 <br><br>
@@ -76,16 +76,21 @@ Train YOLO11n-cls on the MNIST160 dataset for 100 [epochs](https://www.ultralyti
 
     Ultralytics YOLO classification uses [torchvision.transforms.RandomResizedCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.RandomResizedCrop.html) for training augmentation and [torchvision.transforms.CenterCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.CenterCrop.html) for validation/inference.
     For images with extreme aspect ratios, consider using [torchvision.transforms.Resize](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) instead. The example below shows how to customize augmentations for classification training.
+
     ```python
     import torch
     import torchvision.transforms as T
 
+    from ultralytics import YOLO
     from ultralytics.data.dataset import ClassificationDataset
     from ultralytics.models.yolo.classify import ClassificationTrainer
 
 
     class CustomizedDataset(ClassificationDataset):
+        """A customized dataset class for image classification with enhanced data augmentation transforms."""
+
         def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
+            """Initialize a customized classification dataset with enhanced data augmentation transforms."""
             super().__init__(root, args, augment, prefix)
             train_transforms = T.Compose(
                 [
@@ -110,12 +115,13 @@ Train YOLO11n-cls on the MNIST160 dataset for 100 [epochs](https://www.ultralyti
 
 
     class CustomizedTrainer(ClassificationTrainer):
+        """A customized trainer class for YOLO classification models with enhanced dataset handling."""
+
         def build_dataset(self, img_path: str, mode: str = "train", batch=None):
+            """Build a customized dataset for classification training or validation."""
             return CustomizedDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
 
 
-    from ultralytics import YOLO
-
     model = YOLO("yolo11n-cls.pt")
     model.train(data="imagenet1000", trainer=CustomizedTrainer, epochs=10, imgsz=224, batch=64)
     ```
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-__version__ = "8.3.153"
+__version__ = "8.3.154"
 
 import os
 
@@ -82,7 +82,6 @@ class BaseValidator:
         update_metrics: Update metrics based on predictions and batch.
         finalize_metrics: Finalize and return all metrics.
         get_stats: Return statistics about the model's performance.
-        check_stats: Check statistics.
         print_results: Print the results of the model's predictions.
         get_desc: Get description of the YOLO model.
         on_plot: Register plots for visualization.
@@ -226,7 +225,6 @@
 
             self.run_callbacks("on_val_batch_end")
         stats = self.get_stats()
-        self.check_stats(stats)
         self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
         self.finalize_metrics()
         self.print_results()
@@ -334,10 +332,6 @@
         """Return statistics about the model's performance."""
         return {}
 
-    def check_stats(self, stats):
-        """Check statistics."""
-        pass
-
     def print_results(self):
         """Print the results of the model's predictions."""
         pass
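
Since the `check_stats` hook is removed and the validation loop now goes straight from `get_stats` to `finalize_metrics`, any subclass that overrode it can fold the same sanity check into `get_stats`. A minimal sketch of that pattern; `MyValidator` and its stats dict are hypothetical:

```python
from ultralytics.engine.validator import BaseValidator


class MyValidator(BaseValidator):
    """Hypothetical subclass showing where former check_stats() logic can move."""

    def get_stats(self):
        stats = {"metrics/precision": 0.5}  # assumed stats dict, for illustration only
        # The check that previously lived in check_stats() now runs here,
        # since the loop still calls get_stats() directly.
        assert all(v >= 0 for v in stats.values()), "unexpected negative metric"
        return stats
```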
@@ -1,7 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from ultralytics.models.yolo.segment import SegmentationValidator
-from ultralytics.utils.metrics import SegmentMetrics
 
 
 class FastSAMValidator(SegmentationValidator):
@@ -39,4 +38,3 @@ class FastSAMValidator(SegmentationValidator):
         super().__init__(dataloader, save_dir, args, _callbacks)
         self.args.task = "segment"
         self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
-        self.metrics = SegmentMetrics(save_dir=self.save_dir)
@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from typing import Any, Dict, List, Tuple, Union
+
 import torch
 
 from ultralytics.data import YOLODataset
@@ -151,15 +153,21 @@ class RTDETRValidator(DetectionValidator):
             data=self.data,
         )
 
-    def postprocess(self, preds):
+    def postprocess(
+        self, preds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
+    ) -> List[Dict[str, torch.Tensor]]:
         """
         Apply Non-maximum suppression to prediction outputs.
 
         Args:
-            preds (list | tuple | torch.Tensor): Raw predictions from the model.
+            preds (torch.Tensor | List | Tuple): Raw predictions from the model. If tensor, should have shape
+                (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class scores.
 
         Returns:
-            (list[torch.Tensor]): List of processed predictions for each image in batch.
+            (List[Dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
+                - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates
+                - 'conf': Tensor of shape (N,) with confidence scores
+                - 'cls': Tensor of shape (N,) with class indices
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -176,18 +184,19 @@ class RTDETRValidator(DetectionValidator):
             pred = pred[score.argsort(descending=True)]
             outputs[i] = pred[score > self.args.conf]
 
-        return outputs
+        return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
 
-    def _prepare_batch(self, si, batch):
+    def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
         Prepare a batch for validation by applying necessary transformations.
 
         Args:
             si (int): Batch index.
-            batch (dict): Batch data containing images and annotations.
+            batch (Dict[str, Any]): Batch data containing images and annotations.
 
         Returns:
-            (dict): Prepared batch with transformed annotations.
+            (Dict[str, Any]): Prepared batch with transformed annotations containing cls, bboxes,
+                ori_shape, imgsz, and ratio_pad.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
@@ -199,20 +208,23 @@ class RTDETRValidator(DetectionValidator):
             bbox = ops.xywh2xyxy(bbox)  # target boxes
             bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
             bbox[..., [1, 3]] *= ori_shape[0]  # native-space pred
-        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+        return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred, pbatch):
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         """
         Prepare predictions by scaling bounding boxes to original image dimensions.
 
         Args:
-            pred (torch.Tensor): Raw predictions.
-            pbatch (dict): Prepared batch information.
+            pred (Dict[str, torch.Tensor]): Raw predictions containing 'cls', 'bboxes', and 'conf'.
+            pbatch (Dict[str, torch.Tensor]): Prepared batch information containing 'ori_shape' and other metadata.
 
         Returns:
-            (torch.Tensor): Predictions scaled to original image dimensions.
+            (Dict[str, torch.Tensor]): Predictions scaled to original image dimensions.
         """
-        predn = pred.clone()
-        predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
-        predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
-        return predn.float()
+        cls = pred["cls"]
+        if self.args.single_cls:
+            cls *= 0
+        bboxes = pred["bboxes"].clone()
+        bboxes[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+        bboxes[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
+        return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
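
The new `postprocess` contract above returns one dict per image instead of a raw tensor. A minimal sketch of consuming that output, using a fake prediction list in the documented format (shapes and threshold are illustrative):

```python
import torch

# Fake postprocess() output: one dict per image with "bboxes", "conf", "cls".
preds = [{"bboxes": torch.rand(8, 4), "conf": torch.rand(8), "cls": torch.randint(0, 80, (8,)).float()}]
for det in preds:
    keep = det["conf"] > 0.5  # e.g. keep only confident detections
    print(det["bboxes"][keep].shape, det["cls"][keep].tolist())
```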
@@ -1,5 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
 import torch
 
 from ultralytics.data import ClassificationDataset, build_dataloader
@@ -48,7 +51,7 @@ class ClassificationValidator(BaseValidator):
         Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
     """
 
-    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
         """
         Initialize ClassificationValidator with dataloader, save directory, and other parameters.
 
@@ -70,28 +73,26 @@ class ClassificationValidator(BaseValidator):
         self.args.task = "classify"
         self.metrics = ClassifyMetrics()
 
-    def get_desc(self):
+    def get_desc(self) -> str:
         """Return a formatted string summarizing classification metrics."""
         return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
 
-    def init_metrics(self, model):
+    def init_metrics(self, model: torch.nn.Module) -> None:
         """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
         self.names = model.names
         self.nc = len(model.names)
-        self.confusion_matrix = ConfusionMatrix(
-            nc=self.nc, conf=self.args.conf, names=self.names.values(), task="classify"
-        )
         self.pred = []
         self.targets = []
+        self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))
 
-    def preprocess(self, batch):
+    def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """Preprocess input batch by moving data to device and converting to appropriate dtype."""
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
         batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
         batch["cls"] = batch["cls"].to(self.device)
         return batch
 
-    def update_metrics(self, preds, batch):
+    def update_metrics(self, preds: torch.Tensor, batch: Dict[str, Any]) -> None:
         """
         Update running metrics with model predictions and batch targets.
 
@@ -127,23 +128,23 @@ class ClassificationValidator(BaseValidator):
             for normalize in True, False:
                 self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
         self.metrics.speed = self.speed
-        self.metrics.confusion_matrix = self.confusion_matrix
         self.metrics.save_dir = self.save_dir
+        self.metrics.confusion_matrix = self.confusion_matrix
 
-    def postprocess(self, preds):
+    def postprocess(self, preds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> torch.Tensor:
         """Extract the primary prediction from model output if it's in a list or tuple format."""
         return preds[0] if isinstance(preds, (list, tuple)) else preds
 
-    def get_stats(self):
+    def get_stats(self) -> Dict[str, float]:
         """Calculate and return a dictionary of metrics by processing targets and predictions."""
         self.metrics.process(self.targets, self.pred)
         return self.metrics.results_dict
 
-    def build_dataset(self, img_path):
+    def build_dataset(self, img_path: str) -> ClassificationDataset:
         """Create a ClassificationDataset instance for validation."""
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
-    def get_dataloader(self, dataset_path, batch_size):
+    def get_dataloader(self, dataset_path: Union[Path, str], batch_size: int) -> torch.utils.data.DataLoader:
         """
         Build and return a data loader for classification validation.
 
@@ -157,17 +158,17 @@ class ClassificationValidator(BaseValidator):
         dataset = self.build_dataset(dataset_path)
         return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
 
-    def print_results(self):
+    def print_results(self) -> None:
         """Print evaluation metrics for the classification model."""
         pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
 
-    def plot_val_samples(self, batch, ni):
+    def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
         """
         Plot validation image samples with their ground truth labels.
 
         Args:
-            batch (dict): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
+            batch (Dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
             ni (int): Batch index used for naming the output file.
 
         Examples:
@@ -175,21 +176,20 @@ class ClassificationValidator(BaseValidator):
            >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
            >>> validator.plot_val_samples(batch, 0)
        """
+        batch["batch_idx"] = torch.arange(len(batch["img"]))  # add batch index for plotting
         plot_images(
-            images=batch["img"],
-            batch_idx=torch.arange(len(batch["img"])),
-            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            labels=batch,
             fname=self.save_dir / f"val_batch{ni}_labels.jpg",
             names=self.names,
             on_plot=self.on_plot,
         )
 
-    def plot_predictions(self, batch, preds, ni):
+    def plot_predictions(self, batch: Dict[str, Any], preds: torch.Tensor, ni: int) -> None:
         """
         Plot images with their predicted class labels and save the visualization.
 
         Args:
-            batch (dict): Batch data containing images and other information.
+            batch (Dict[str, Any]): Batch data containing images and other information.
             preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
             ni (int): Batch index used for naming the output file.
 
@@ -199,10 +199,13 @@ class ClassificationValidator(BaseValidator):
             >>> preds = torch.rand(16, 10)  # 16 images, 10 classes
             >>> validator.plot_predictions(batch, preds, 0)
         """
-        plot_images(
-            batch["img"],
+        batched_preds = dict(
+            img=batch["img"],
             batch_idx=torch.arange(len(batch["img"])),
             cls=torch.argmax(preds, dim=1),
+        )
+        plot_images(
+            batched_preds,
             fname=self.save_dir / f"val_batch{ni}_pred.jpg",
             names=self.names,
             on_plot=self.on_plot,
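
Both plotting methods now hand `plot_images` a single labels dict instead of separate `images`/`batch_idx`/`cls` arguments. A small sketch of the new call shape with a dummy batch; shapes, file name, and class names are made up, and argument handling beyond what the diff shows is assumed:

```python
import torch

from ultralytics.utils.plotting import plot_images

# Dummy 16-image batch; the keys mirror the dicts built in the diff above.
labels = {
    "img": torch.rand(16, 3, 224, 224),
    "batch_idx": torch.arange(16),
    "cls": torch.randint(0, 10, (16,)),
}
plot_images(labels, fname="val_batch0_pred.jpg", names={i: f"class_{i}" for i in range(10)})
```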
@@ -3,7 +3,7 @@
 import math
 import random
 from copy import copy
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 import torch.nn as nn
@@ -178,19 +178,16 @@ class DetectionTrainer(BaseTrainer):
             "Size",
         )
 
-    def plot_training_samples(self, batch: Dict, ni: int):
+    def plot_training_samples(self, batch: Dict[str, Any], ni: int) -> None:
         """
         Plot training samples with their annotations.
 
         Args:
-            batch (Dict): Dictionary containing batch data.
+            batch (Dict[str, Any]): Dictionary containing batch data.
             ni (int): Number of iterations.
         """
         plot_images(
-            images=batch["img"],
-            batch_idx=batch["batch_idx"],
-            cls=batch["cls"].squeeze(-1),
-            bboxes=batch["bboxes"],
+            labels=batch,
             paths=batch["im_file"],
             fname=self.save_dir / f"train_batch{ni}.jpg",
             on_plot=self.on_plot,
@@ -12,7 +12,7 @@ from ultralytics.engine.validator import BaseValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.plotting import plot_images
 
 
 class DetectionValidator(BaseValidator):
@@ -23,8 +23,6 @@ class DetectionValidator(BaseValidator):
     prediction processing, and visualization of results.
 
     Attributes:
-        nt_per_class (np.ndarray): Number of targets per class.
-        nt_per_image (np.ndarray): Number of targets per image.
         is_coco (bool): Whether the dataset is COCO.
         is_lvis (bool): Whether the dataset is LVIS.
         class_map (List[int]): Mapping from model class indices to dataset class indices.
@@ -53,15 +51,13 @@ class DetectionValidator(BaseValidator):
             _callbacks (List[Any], optional): List of callback functions.
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
-        self.nt_per_class = None
-        self.nt_per_image = None
         self.is_coco = False
         self.is_lvis = False
         self.class_map = None
         self.args.task = "detect"
-        self.metrics = DetMetrics(save_dir=self.save_dir)
         self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
         self.niou = self.iouv.numel()
+        self.metrics = DetMetrics()
 
     def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -99,18 +95,16 @@
         self.names = model.names
         self.nc = len(model.names)
         self.end2end = getattr(model, "end2end", False)
-        self.metrics.names = self.names
-        self.metrics.plot = self.args.plots
-        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, names=self.names.values())
         self.seen = 0
         self.jdict = []
-        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+        self.metrics.names = self.names
+        self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))
 
     def get_desc(self) -> str:
         """Return a formatted string summarizing class metrics of YOLO model."""
         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
 
-    def postprocess(self, preds: torch.Tensor) -> List[torch.Tensor]:
+    def postprocess(self, preds: torch.Tensor) -> List[Dict[str, torch.Tensor]]:
         """
         Apply Non-maximum suppression to prediction outputs.
 
@@ -118,9 +112,10 @@
             preds (torch.Tensor): Raw predictions from the model.
 
         Returns:
-            (List[torch.Tensor]): Processed predictions after NMS.
+            (List[Dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
+                'bboxes', 'conf', 'cls', and 'extra' tensors.
         """
-        return ops.non_max_suppression(
+        outputs = ops.non_max_suppression(
             preds,
             self.args.conf,
             self.args.iou,
@@ -131,6 +126,7 @@
             end2end=self.end2end,
             rotated=self.args.task == "obb",
         )
+        return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
 
     def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -152,68 +148,60 @@
         if len(cls):
             bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
-        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+        return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         """
         Prepare predictions for evaluation against ground truth.
 
         Args:
-            pred (torch.Tensor): Model predictions.
+            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
             pbatch (Dict[str, Any]): Prepared batch information.
 
         Returns:
-            (torch.Tensor): Prepared predictions in native space.
-        """
-        predn = pred.clone()
-        ops.scale_boxes(
-            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+            (Dict[str, torch.Tensor]): Prepared predictions in native space.
+        """
+        cls = pred["cls"]
+        if self.args.single_cls:
+            cls *= 0
+        # predn = pred.clone()
+        bboxes = ops.scale_boxes(
+            pbatch["imgsz"], pred["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
         )  # native-space pred
-        return predn
+        return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
 
-    def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]) -> None:
+    def update_metrics(self, preds: List[Dict[str, torch.Tensor]], batch: Dict[str, Any]) -> None:
         """
         Update metrics with new predictions and ground truth.
 
         Args:
-            preds (List[torch.Tensor]): List of predictions from the model.
+            preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
             batch (Dict[str, Any]): Batch data containing ground truth.
         """
         for si, pred in enumerate(preds):
             self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-            )
             pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
             predn = self._prepare_pred(pred, pbatch)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
 
+            cls = pbatch["cls"].cpu().numpy()
+            no_pred = len(predn["cls"]) == 0
+            if no_pred and len(cls) == 0:
+                continue
+            self.metrics.update_stats(
+                {
+                    **self._process_batch(predn, pbatch),
+                    "target_cls": cls,
+                    "target_img": np.unique(cls),
+                    "conf": np.zeros(0) if no_pred else predn["conf"].cpu().numpy(),
+                    "pred_cls": np.zeros(0) if no_pred else predn["cls"].cpu().numpy(),
+                }
+            )
             # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
             if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
+                self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)
+
+            if no_pred:
+                continue
 
             # Save
             if self.args.save_json:
@@ -241,44 +229,45 @@
         Returns:
             (Dict[str, Any]): Dictionary containing metrics results.
         """
-        stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
-        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
-        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
-        stats.pop("target_img", None)
-        if len(stats):
-            self.metrics.process(**stats, on_plot=self.on_plot)
+        self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)
+        self.metrics.clear_stats()
         return self.metrics.results_dict
 
     def print_results(self) -> None:
         """Print training/validation set metrics per class."""
         pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
-        LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
-        if self.nt_per_class.sum() == 0:
+        LOGGER.info(pf % ("all", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))
+        if self.metrics.nt_per_class.sum() == 0:
             LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
 
         # Print results per class
-        if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
+        if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):
             for i, c in enumerate(self.metrics.ap_class_index):
                 LOGGER.info(
-                    pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
+                    pf
+                    % (
+                        self.names[c],
+                        self.metrics.nt_per_image[c],
+                        self.metrics.nt_per_class[c],
+                        *self.metrics.class_result(i),
+                    )
                 )
 
-    def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         Return correct prediction matrix.
 
         Args:
-            detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is
-                (x1, y1, x2, y2, conf, class).
-            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each
-                bounding box is of the format: (x1, y1, x2, y2).
-            gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices.
+            preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
+            batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
 
         Returns:
-            (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
+            (Dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
         """
-        iou = box_iou(gt_bboxes, detections[:, :4])
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
+        if len(batch["cls"]) == 0 or len(preds["cls"]) == 0:
+            return {"tp": np.zeros((len(preds["cls"]), self.niou), dtype=bool)}
+        iou = box_iou(batch["bboxes"], preds["bboxes"])
+        return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
 
     def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None) -> torch.utils.data.Dataset:
         """
@@ -317,42 +306,50 @@
             ni (int): Batch index.
         """
         plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
+            labels=batch,
             paths=batch["im_file"],
             fname=self.save_dir / f"val_batch{ni}_labels.jpg",
             names=self.names,
             on_plot=self.on_plot,
         )
 
-    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
+    def plot_predictions(
+        self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int, max_det: Optional[int] = None
+    ) -> None:
         """
         Plot predicted bounding boxes on input images and save the result.
 
         Args:
             batch (Dict[str, Any]): Batch containing images and annotations.
-            preds (List[torch.Tensor]): List of predictions from the model.
+            preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
             ni (int): Batch index.
-        """
+            max_det (Optional[int]): Maximum number of detections to plot.
+        """
+        # TODO: optimize this
+        for i, pred in enumerate(preds):
+            pred["batch_idx"] = torch.ones_like(pred["conf"]) * i  # add batch index to predictions
+        keys = preds[0].keys()
+        max_det = max_det or self.args.max_det
+        batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}
+        # TODO: fix this
+        batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4])  # convert to xywh format
         plot_images(
-            batch["img"],
-            *output_to_target(preds, max_det=self.args.max_det),
+            images=batch["img"],
+            labels=batched_preds,
             paths=batch["im_file"],
             fname=self.save_dir / f"val_batch{ni}_pred.jpg",
             names=self.names,
             on_plot=self.on_plot,
         )  # pred
 
-    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
+    def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
-            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
             save_conf (bool): Whether to save confidence scores.
-            shape (Tuple[int, int]): Shape of the original image.
+            shape (Tuple[int, int]): Shape of the original image (height, width).
             file (Path): File path to save the detections.
         """
         from ultralytics.engine.results import Results
@@ -361,28 +358,29 @@
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
         """
         Serialize YOLO predictions to COCO json format.
 
         Args:
-            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
+                with bounding box coordinates, confidence scores, and class predictions.
             filename (str): Image filename.
         """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        for p, b in zip(predn.tolist(), box.tolist()):
+        for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(c)],
                     "bbox": [round(x, 3) for x in b],
-                    "score": round(p[4], 5),
+                    "score": round(s, 5),
                 }
             )
 
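With predictions now carried as dicts, helpers like `save_one_txt` and `pred_to_json` read the `'bboxes'` (native-space xyxy), `'conf'`, and `'cls'` keys directly. A hand-built example of that structure; the values, filename, and `validator` instance are invented for illustration:

```python
import torch

predn = {
    "bboxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]),  # (N, 4) xyxy, native space
    "conf": torch.tensor([0.91]),  # (N,) confidence scores
    "cls": torch.tensor([0.0]),  # (N,) class indices
}
# validator.pred_to_json(predn, "example.jpg")  # `validator` assumed to exist
```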
@@ -3,12 +3,12 @@
 from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union
 
+import numpy as np
 import torch
 
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.metrics import OBBMetrics, batch_probiou
-from ultralytics.utils.plotting import output_to_rotated_target, plot_images
 
 
 class OBBValidator(DetectionValidator):
@@ -55,7 +55,7 @@ class OBBValidator(DetectionValidator):
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
         self.args.task = "obb"
-        self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
+        self.metrics = OBBMetrics()
 
     def init_metrics(self, model: torch.nn.Module) -> None:
         """
@@ -68,20 +68,20 @@ class OBBValidator(DetectionValidator):
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
 
-    def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
         """
         Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
 
         Args:
-            detections (torch.Tensor): Detected bounding boxes and associated data with shape (N, 7) where each
-                detection is represented as (x1, y1, x2, y2, conf, class, angle).
-            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (M, 5) where each box is represented
-                as (x1, y1, x2, y2, angle).
-            gt_cls (torch.Tensor): Class labels for the ground truth bounding boxes with shape (M,).
+            preds (Dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
+                class labels and bounding boxes.
+            batch (Dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
+                class labels and bounding boxes.
 
         Returns:
-            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU levels for each
-                detection, indicating the accuracy of predictions compared to the ground truth.
+            (Dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
+                array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
+                of predictions compared to the ground truth.
 
         Examples:
             >>> detections = torch.rand(100, 7)  # 100 sample detections
@@ -89,10 +89,25 @@ class OBBValidator(DetectionValidator):
             >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
             >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
         """
-        iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
+        if len(batch["cls"]) == 0 or len(preds["cls"]) == 0:
+            return {"tp": np.zeros((len(preds["cls"]), self.niou), dtype=bool)}
+        iou = batch_probiou(batch["bboxes"], preds["bboxes"])
+        return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
 
-    def _prepare_batch(self, si: int, batch: Dict) -> Dict:
+    def postprocess(self, preds: torch.Tensor) -> List[Dict[str, torch.Tensor]]:
+        """
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (List[Dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
+        """
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1)  # concatenate angle
+        return preds
+
+    def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
         Prepare batch data for OBB validation with proper scaling and formatting.
 
@@ -118,9 +133,9 @@
         if len(cls):
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
-        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+        return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         """
         Prepare predictions by scaling bounding boxes to original image dimensions.
 
@@ -128,20 +143,22 @@
         input dimensions to the original image dimensions using the provided batch information.
 
         Args:
-            pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
+            pred (Dict[str, torch.Tensor]): Prediction dictionary containing bounding box coordinates and other information.
             pbatch (Dict[str, Any]): Dictionary containing batch information with keys:
                 - imgsz (tuple): Model input image size.
                 - ori_shape (tuple): Original image shape.
                 - ratio_pad (tuple): Ratio and padding information for scaling.
 
         Returns:
-            (torch.Tensor): Scaled prediction tensor with bounding boxes in original image dimensions.
+            (Dict[str, torch.Tensor]): Scaled prediction dictionary with bounding boxes in original image dimensions.
         """
-        predn = pred.clone()
-        ops.scale_boxes(
-            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+        cls = pred["cls"]
+        if self.args.single_cls:
+            cls *= 0
+        bboxes = ops.scale_boxes(
+            pbatch["imgsz"], pred["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
         )  # native-space pred
-        return predn
+        return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
 
     def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
         """
@@ -158,22 +175,18 @@
             >>> preds = [torch.rand(10, 7)]  # Example predictions for one image
             >>> validator.plot_predictions(batch, preds, 0)
         """
-        plot_images(
-            batch["img"],
-            *output_to_rotated_target(preds, max_det=self.args.max_det),
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
+        for p in preds:
+            # TODO: fix this duplicated `xywh2xyxy`
+            p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4])  # convert to xyxy format for plotting
+        super().plot_predictions(batch, preds, ni)  # plot bboxes
 
-    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]) -> None:
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: Union[str, Path]) -> None:
         """
         Convert YOLO predictions to COCO JSON format with rotated bounding box information.
 
         Args:
-            predn (torch.Tensor): Prediction tensor containing bounding box coordinates, confidence scores,
-                class predictions, and rotation angles with shape (N, 6+) where the last column is the angle.
+            predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
+                with bounding box coordinates, confidence scores, and class predictions.
             filename (str | Path): Path to the image file for which predictions are being processed.
 
         Notes:
@@ -183,22 +196,20 @@
         """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+        rbox = predn["bboxes"]
         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
-        for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
+        for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(predn[i, 5].item())],
-                    "score": round(predn[i, 4].item(), 5),
+                    "category_id": self.class_map[int(c)],
+                    "score": round(s, 5),
                     "rbox": [round(x, 3) for x in r],
                     "poly": [round(x, 3) for x in b],
                 }
             )
 
-    def save_one_txt(
-        self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]
-    ) -> None:
+    def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         """
         Save YOLO OBB detections to a text file in normalized coordinates.
         Save YOLO OBB detections to a text file in normalized coordinates.
 
 
@@ -207,7 +218,7 @@ class OBBValidator(DetectionValidator):
                 class predictions, and angles in format (x, y, w, h, conf, cls, angle).
                 class predictions, and angles in format (x, y, w, h, conf, cls, angle).
             save_conf (bool): Whether to save confidence scores in the text file.
             save_conf (bool): Whether to save confidence scores in the text file.
             shape (Tuple[int, int]): Original image shape in format (height, width).
             shape (Tuple[int, int]): Original image shape in format (height, width).
-            file (Path | str): Output file path to save detections.
+            file (Path): Output file path to save detections.
 
 
         Examples:
         Examples:
             >>> validator = OBBValidator()
             >>> validator = OBBValidator()
@@ -218,14 +229,11 @@ class OBBValidator(DetectionValidator):
 
 
         from ultralytics.engine.results import Results
         from ultralytics.engine.results import Results
 
 
-        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
-        # xywh, r, conf, cls
-        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
         Results(
         Results(
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             path=None,
             names=self.names,
             names=self.names,
-            obb=obb,
+            obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
         ).save_txt(file, save_conf=save_conf)
         ).save_txt(file, save_conf=save_conf)
 
 
     def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
     def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
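Worth noting for reviewers: per-image OBB predictions now travel as a dict with `"bboxes"` holding `(x, y, w, h, angle)` rows plus separate `"conf"` and `"cls"` tensors, instead of one `(N, 7)` tensor. A minimal sketch of consuming that format the way `pred_to_json` does — the random tensors below are placeholders, not part of this PR:

```python
import torch

from ultralytics.utils import ops

# Hypothetical dict-based OBB predictions in the new format.
predn = {
    "bboxes": torch.rand(3, 5),  # (x, y, w, h, angle) per detection
    "conf": torch.rand(3),
    "cls": torch.zeros(3),
}

# Same conversion pred_to_json uses: rotated boxes -> 4-corner polygons.
poly = ops.xywhr2xyxyxyxy(predn["bboxes"]).view(-1, 8)
for r, p, s, c in zip(predn["bboxes"].tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
    print({"category_id": int(c), "score": round(s, 5), "rbox": [round(x, 3) for x in r], "poly": [round(x, 3) for x in p]})
```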
@@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import PoseModel
 from ultralytics.utils import DEFAULT_CFG, LOGGER
-from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.utils.plotting import plot_results
 
 
 class PoseTrainer(yolo.detect.DetectionTrainer):
@@ -108,40 +108,6 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def plot_training_samples(self, batch: Dict[str, Any], ni: int):
-        """
-        Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
-
-        Args:
-            batch (dict): Dictionary containing batch data with the following keys:
-                - img (torch.Tensor): Batch of images
-                - keypoints (torch.Tensor): Keypoints coordinates for pose estimation
-                - cls (torch.Tensor): Class labels
-                - bboxes (torch.Tensor): Bounding box coordinates
-                - im_file (list): List of image file paths
-                - batch_idx (torch.Tensor): Batch indices for each instance
-            ni (int): Current training iteration number used for filename
-
-        The function saves the plotted batch as an image in the trainer's save directory with the filename
-        'train_batch{ni}.jpg', where ni is the iteration number.
-        """
-        images = batch["img"]
-        kpts = batch["keypoints"]
-        cls = batch["cls"].squeeze(-1)
-        bboxes = batch["bboxes"]
-        paths = batch["im_file"]
-        batch_idx = batch["batch_idx"]
-        plot_images(
-            images,
-            batch_idx,
-            cls,
-            bboxes,
-            kpts=kpts,
-            paths=paths,
-            fname=self.save_dir / f"train_batch{ni}.jpg",
-            on_plot=self.on_plot,
-        )
-
     def plot_metrics(self):
         """Plot training/validation metrics."""
         plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
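Since `PoseTrainer.plot_training_samples` is deleted here rather than moved, training-batch plotting presumably falls through to the shared base-class implementation, so end-user training code should be unchanged. A quick smoke test, assuming the stock `coco8-pose.yaml` demo dataset:

```python
from ultralytics import YOLO

# Train for one epoch and check that train_batch*.jpg previews still land in the run directory.
model = YOLO("yolo11n-pose.pt")
model.train(data="coco8-pose.yaml", epochs=1, imgsz=640)
```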
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -9,8 +9,7 @@ import torch
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
 
 
 class PoseValidator(DetectionValidator):
@@ -33,7 +32,6 @@ class PoseValidator(DetectionValidator):
         _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
             dimensions.
         _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
-        update_metrics: Update metrics with new predictions and ground truth data.
         _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
             detections and ground truth.
         plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
@@ -77,7 +75,7 @@ class PoseValidator(DetectionValidator):
         self.sigma = None
         self.kpt_shape = None
         self.args.task = "pose"
-        self.metrics = PoseMetrics(save_dir=self.save_dir)
+        self.metrics = PoseMetrics()
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
             LOGGER.warning(
                 "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
@@ -118,7 +116,36 @@ class PoseValidator(DetectionValidator):
         is_pose = self.kpt_shape == [17, 3]
         nkpt = self.kpt_shape[0]
         self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
-        self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+
+    def postprocess(self, preds: torch.Tensor) -> List[Dict[str, torch.Tensor]]:
+        """
+        Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
+
+        This method extends the parent class postprocessing by extracting keypoints from the 'extra'
+        field of predictions and reshaping them according to the keypoint shape configuration.
+        The keypoints are reshaped from a flattened format to the proper dimensional structure
+        (typically [N, 17, 3] for COCO pose format).
+
+        Args:
+            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
+                bounding boxes, confidence scores, class predictions, and keypoint data.
+
+        Returns:
+            (List[Dict[str, torch.Tensor]]): List of processed prediction dictionaries, each containing:
+                - 'bboxes': Bounding box coordinates
+                - 'conf': Confidence scores
+                - 'cls': Class predictions
+                - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
+
+        Note:
+            The keypoints are extracted from the 'extra' field, which carries task-specific data
+            beyond basic detection. An empty 'extra' tensor simply reshapes to an empty keypoints
+            tensor, so images with no detections need no special handling.
+        """
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["keypoints"] = pred.pop("extra").reshape(-1, *self.kpt_shape)  # remove extra if exists
+        return preds
 
     def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -142,10 +169,10 @@ class PoseValidator(DetectionValidator):
         kpts[..., 0] *= w
         kpts[..., 1] *= h
         kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-        pbatch["kpts"] = kpts
+        pbatch["keypoints"] = kpts
         return pbatch
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _prepare_pred(self, pred: Dict[str, Any], pbatch: Dict[str, Any]) -> Dict[str, Any]:
         """
         Prepare and scale keypoints in predictions for pose processing.
 
@@ -154,189 +181,59 @@ class PoseValidator(DetectionValidator):
         to match the original image dimensions.
 
         Args:
-            pred (torch.Tensor): Raw prediction tensor from the model.
+            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
             pbatch (Dict[str, Any]): Processed batch dictionary containing image information including:
                 - imgsz: Image size used for inference
                 - ori_shape: Original image shape
                 - ratio_pad: Ratio and padding information for coordinate scaling
 
         Returns:
-            predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
-            pred_kpts (torch.Tensor): Predicted keypoints scaled to original image dimensions.
+            (Dict[str, Any]): Processed prediction dictionary with keypoints scaled to original image dimensions.
         """
         predn = super()._prepare_pred(pred, pbatch)
-        nk = pbatch["kpts"].shape[1]
-        pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
-        ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-        return predn, pred_kpts
-
-    def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]) -> None:
-        """
-        Update metrics with new predictions and ground truth data.
-
-        This method processes each prediction, compares it with ground truth, and updates various statistics
-        for performance evaluation.
+        predn["keypoints"] = ops.scale_coords(
+            pbatch["imgsz"], pred.get("keypoints").clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+        )
+        return predn
 
-        Args:
-            preds (List[torch.Tensor]): List of prediction tensors from the model.
-            batch (Dict[str, Any]): Batch data containing images and ground truth annotations.
-        """
-        for si, pred in enumerate(preds):
-            self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-            )
-            pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
-            predn, pred_kpts = self._prepare_pred(pred, pbatch)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
-
-            # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
-                stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
-            if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
-
-            # Save
-            if self.args.save_json:
-                self.pred_to_json(predn, batch["im_file"][si])
-            if self.args.save_txt:
-                self.save_one_txt(
-                    predn,
-                    pred_kpts,
-                    self.args.save_conf,
-                    pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
-                )
-
-    def _process_batch(
-        self,
-        detections: torch.Tensor,
-        gt_bboxes: torch.Tensor,
-        gt_cls: torch.Tensor,
-        pred_kpts: Optional[torch.Tensor] = None,
-        gt_kpts: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         """
         Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
         Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
 
 
         Args:
         Args:
-            detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
-                detection is of the format (x1, y1, x2, y2, conf, class).
-            gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
-                box is of the format (x1, y1, x2, y2).
-            gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
-            pred_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing predicted keypoints, where
-                51 corresponds to 17 keypoints each having 3 values.
-            gt_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing ground truth keypoints.
+            preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
+                and 'keypoints' for keypoint predictions.
+            batch (Dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
+                'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
 
         Returns:
-            (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
-                where N is the number of detections.
+            (Dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
+                true positives across 10 IoU levels.
 
         Notes:
             `0.53` scale factor used in area computation is referenced from
             https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
         """
-        if pred_kpts is not None and gt_kpts is not None:
+        tp = super()._process_batch(preds, batch)
+        gt_cls = batch["cls"]
+        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
+            tp_p = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
+        else:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
-            area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
-            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
-        else:  # boxes
-            iou = box_iou(gt_bboxes, detections[:, :4])
+            area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
+            iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
+            tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_p": tp_p})  # update tp with kpts IoU
+        return tp
 
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
-
-    def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
-        """
-        Plot and save validation set samples with ground truth bounding boxes and keypoints.
-
-        Args:
-            batch (Dict[str, Any]): Dictionary containing batch data with keys:
-                - img (torch.Tensor): Batch of images
-                - batch_idx (torch.Tensor): Batch indices for each image
-                - cls (torch.Tensor): Class labels
-                - bboxes (torch.Tensor): Bounding box coordinates
-                - keypoints (torch.Tensor): Keypoint coordinates
-                - im_file (list): List of image file paths
-            ni (int): Batch index used for naming the output file
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            kpts=batch["keypoints"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )
-
-    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
-        """
-        Plot and save model predictions with bounding boxes and keypoints.
-
-        Args:
-            batch (Dict[str, Any]): Dictionary containing batch data including images, file paths, and other metadata.
-            preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding boxes,
-                confidence scores, class predictions, and keypoints.
-            ni (int): Batch index used for naming the output file.
-
-        The function extracts keypoints from predictions, converts predictions to target format, and plots them
-        on the input images. The resulting visualization is saved to the specified save directory.
-        """
-        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
-        plot_images(
-            batch["img"],
-            *output_to_target(preds, max_det=self.args.max_det),
-            kpts=pred_kpts,
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
-
-    def save_one_txt(
-        self,
-        predn: torch.Tensor,
-        pred_kpts: torch.Tensor,
-        save_conf: bool,
-        shape: Tuple[int, int],
-        file: Path,
-    ) -> None:
+    def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         """
         Save YOLO pose detections to a text file in normalized coordinates.
         Save YOLO pose detections to a text file in normalized coordinates.
 
 
         Args:
         Args:
-            predn (torch.Tensor): Prediction boxes and scores with shape (N, 6) for (x1, y1, x2, y2, conf, cls).
-            pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D) where K is the number of keypoints
-                and D is the dimension (typically 3 for x, y, visibility).
+            predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.
             save_conf (bool): Whether to save confidence scores.
             save_conf (bool): Whether to save confidence scores.
-            shape (tuple): Original image shape (height, width).
+            shape (Tuple[int, int]): Shape of the original image (height, width).
             file (Path): Output file path to save detections.
             file (Path): Output file path to save detections.
 
 
         Notes:
         Notes:
@@ -349,11 +246,11 @@ class PoseValidator(DetectionValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
-            keypoints=pred_kpts,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            keypoints=predn["keypoints"],
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
         """
         """
         Convert YOLO predictions to COCO JSON format.
         Convert YOLO predictions to COCO JSON format.
 
 
@@ -361,10 +258,9 @@ class PoseValidator(DetectionValidator):
         to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
 
         Args:
-            predn (torch.Tensor): Prediction tensor containing bounding boxes, confidence scores, class IDs,
-                and keypoints, with shape (N, 6+K) where N is the number of predictions and K is the flattened
-                keypoints dimension.
-            filename (str | Path): Path to the image file for which predictions are being processed.
+            predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
+                and 'keypoints' tensors.
+            filename (str): Path to the image file for which predictions are being processed.
 
         Notes:
             The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
@@ -373,16 +269,21 @@ class PoseValidator(DetectionValidator):
         """
         """
         stem = Path(filename).stem
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        for p, b in zip(predn.tolist(), box.tolist()):
+        for b, s, c, k in zip(
+            box.tolist(),
+            predn["conf"].tolist(),
+            predn["cls"].tolist(),
+            predn["keypoints"].flatten(1, 2).tolist(),
+        ):
             self.jdict.append(
             self.jdict.append(
                 {
                 {
                     "image_id": image_id,
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(c)],
                     "bbox": [round(x, 3) for x in b],
                     "bbox": [round(x, 3) for x in b],
-                    "keypoints": p[6:],
-                    "score": round(p[4], 5),
+                    "keypoints": k,
+                    "score": round(s, 5),
                 }
                 }
             )
             )
 
 
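The pose `_process_batch` now just layers keypoint true positives (`tp_p`) on top of the box-level `tp` dict returned by the parent class. A self-contained sketch of the OKS computation it delegates to `kpt_iou`, with random tensors standing in for real predictions (shapes follow the COCO 17-keypoint layout):

```python
import torch

from ultralytics.utils import ops
from ultralytics.utils.metrics import OKS_SIGMA, kpt_iou

gt_kpts = torch.rand(4, 17, 3)  # placeholder ground truth: (x, y, visibility) per keypoint
pred_kpts = torch.rand(5, 17, 3)  # placeholder predictions
gt_boxes = torch.tensor([[0.0, 0.0, 1.0, 1.0]]).repeat(4, 1)  # xyxy boxes of the GT objects

# The 0.53 area scale factor mirrors xtcocotools' cocoeval, as noted in the diff.
area = ops.xyxy2xywh(gt_boxes)[:, 2:].prod(1) * 0.53
iou = kpt_iou(gt_kpts, pred_kpts, sigma=OKS_SIGMA, area=area)
print(iou.shape)  # (4, 5) OKS matrix: ground truths x predictions
```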
@@ -7,7 +7,7 @@ from typing import Dict, Optional, Union
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.utils import DEFAULT_CFG, RANK
-from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.utils.plotting import plot_results
 
 
 class SegmentationTrainer(yolo.detect.DetectionTrainer):
@@ -82,46 +82,6 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def plot_training_samples(self, batch: Dict, ni: int):
-        """
-        Plot training sample images with labels, bounding boxes, and masks.
-
-        This method creates a visualization of training batch images with their corresponding labels, bounding boxes,
-        and segmentation masks, saving the result to a file for inspection and debugging.
-
-        Args:
-            batch (dict): Dictionary containing batch data with the following keys:
-                'img': Images tensor
-                'batch_idx': Batch indices for each box
-                'cls': Class labels tensor (squeezed to remove last dimension)
-                'bboxes': Bounding box coordinates tensor
-                'masks': Segmentation masks tensor
-                'im_file': List of image file paths
-            ni (int): Current training iteration number, used for naming the output file.
-
-        Examples:
-            >>> trainer = SegmentationTrainer()
-            >>> batch = {
-            ...     "img": torch.rand(16, 3, 640, 640),
-            ...     "batch_idx": torch.zeros(16),
-            ...     "cls": torch.randint(0, 80, (16, 1)),
-            ...     "bboxes": torch.rand(16, 4),
-            ...     "masks": torch.rand(16, 640, 640),
-            ...     "im_file": ["image1.jpg", "image2.jpg"],
-            ... }
-            >>> trainer.plot_training_samples(batch, ni=5)
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            masks=batch["masks"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"train_batch{ni}.jpg",
-            on_plot=self.on_plot,
-        )
-
     def plot_metrics(self):
         """Plot training/validation metrics."""
         plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
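As with the pose trainer above, the segment-specific `plot_training_samples` override is removed with no replacement in this file, so the base-class plotting path presumably covers masks now. An end-to-end sanity check of the refactored validator (see the next file diff), assuming the stock `coco8-seg.yaml` demo dataset:

```python
from ultralytics import YOLO

model = YOLO("yolo11n-seg.pt")
metrics = model.val(data="coco8-seg.yaml", imgsz=640)  # exercises SegmentationValidator below
print(metrics.box.map, metrics.seg.map)  # box and mask mAP50-95
```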
@@ -2,7 +2,7 @@
 
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -11,8 +11,7 @@ import torch.nn.functional as F
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, NUM_THREADS, ops
 from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.metrics import SegmentMetrics, mask_iou
 
 
 class SegmentationValidator(DetectionValidator):
@@ -47,10 +46,9 @@ class SegmentationValidator(DetectionValidator):
             _callbacks (list, optional): List of callback functions.
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
-        self.plot_masks = None
         self.process = None
         self.args.task = "segment"
-        self.metrics = SegmentMetrics(save_dir=self.save_dir)
+        self.metrics = SegmentMetrics()
 
     def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -74,12 +72,10 @@ class SegmentationValidator(DetectionValidator):
             model (torch.nn.Module): Model to validate.
         """
         super().init_metrics(model)
-        self.plot_masks = []
         if self.args.save_json:
             check_requirements("pycocotools>=2.0.6")
         # More accurate vs faster
         self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
-        self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
     def get_desc(self) -> str:
         """Return a formatted description of evaluation metrics."""
@@ -97,7 +93,7 @@ class SegmentationValidator(DetectionValidator):
             "mAP50-95)",
             "mAP50-95)",
         )
         )
 
 
-    def postprocess(self, preds: List[torch.Tensor]) -> Tuple[List[torch.Tensor], torch.Tensor]:
+    def postprocess(self, preds: List[torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
         """
         """
         Post-process YOLO predictions and return output detections with proto.
         Post-process YOLO predictions and return output detections with proto.
 
 
@@ -105,12 +101,19 @@ class SegmentationValidator(DetectionValidator):
             preds (List[torch.Tensor]): Raw predictions from the model.
 
         Returns:
-            p (List[torch.Tensor]): Processed detection predictions.
-            proto (torch.Tensor): Prototype masks for segmentation.
+            (List[Dict[str, torch.Tensor]]): Processed detection predictions with masks.
         """
-        p = super().postprocess(preds[0])
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
-        return p, proto
+        preds = super().postprocess(preds[0])
+        imgsz = [4 * x for x in proto.shape[2:]]  # get image size from proto
+        for i, pred in enumerate(preds):
+            coefficient = pred.pop("extra")
+            pred["masks"] = (
+                self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
+                if len(coefficient)
+                else torch.zeros((0, imgsz[0], imgsz[1]), dtype=torch.uint8, device=pred["bboxes"].device)
+            )
+        return preds
 
     def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -128,142 +131,56 @@ class SegmentationValidator(DetectionValidator):
         prepared_batch["masks"] = batch["masks"][midx]
         prepared_batch["masks"] = batch["masks"][midx]
         return prepared_batch
         return prepared_batch
 
 
-    def _prepare_pred(
-        self, pred: torch.Tensor, pbatch: Dict[str, Any], proto: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         """
         """
         Prepare predictions for evaluation by processing bounding boxes and masks.
         Prepare predictions for evaluation by processing bounding boxes and masks.
 
 
         Args:
         Args:
-            pred (torch.Tensor): Raw predictions from the model.
+            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
             pbatch (Dict[str, Any]): Prepared batch information.
             pbatch (Dict[str, Any]): Prepared batch information.
-            proto (torch.Tensor): Prototype masks for segmentation.
 
         Returns:
-            predn (torch.Tensor): Processed bounding box predictions.
-            pred_masks (torch.Tensor): Processed mask predictions.
+            (Dict[str, torch.Tensor]): Processed predictions with scaled bounding boxes and masks.
         """
         predn = super()._prepare_pred(pred, pbatch)
-        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
-        return predn, pred_masks
-
-    def update_metrics(self, preds: Tuple[List[torch.Tensor], torch.Tensor], batch: Dict[str, Any]) -> None:
-        """
-        Update metrics with the current batch predictions and targets.
-
-        Args:
-            preds (Tuple[List[torch.Tensor], torch.Tensor]): List of predictions from the model.
-            batch (Dict[str, Any]): Batch data containing ground truth.
-        """
-        for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
-            self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+        predn["masks"] = pred["masks"]
+        if self.args.save_json and len(predn["masks"]):
+            coco_masks = torch.as_tensor(pred["masks"], dtype=torch.uint8)
+            coco_masks = ops.scale_image(
+                coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+                pbatch["ori_shape"],
+                ratio_pad=pbatch["ratio_pad"],
             )
-            pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Masks
-            gt_masks = pbatch.pop("masks")
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
-            predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
-
-            # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
-                stat["tp_m"] = self._process_batch(
-                    predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
-                )
-            if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
-
-            pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
-            if self.args.plots and self.batch_i < 3:
-                self.plot_masks.append(pred_masks[:50].cpu())  # Limit plotted items for speed
-                if pred_masks.shape[0] > 50:
-                    LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
-
-            # Save
-            if self.args.save_json:
-                self.pred_to_json(
-                    predn,
-                    batch["im_file"][si],
-                    ops.scale_image(
-                        pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
-                        pbatch["ori_shape"],
-                        ratio_pad=batch["ratio_pad"][si],
-                    ),
-                )
-            if self.args.save_txt:
-                self.save_one_txt(
-                    predn,
-                    pred_masks,
-                    self.args.save_conf,
-                    pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
-                )
-
-    def _process_batch(
-        self,
-        detections: torch.Tensor,
-        gt_bboxes: torch.Tensor,
-        gt_cls: torch.Tensor,
-        pred_masks: Optional[torch.Tensor] = None,
-        gt_masks: Optional[torch.Tensor] = None,
-        overlap: Optional[bool] = False,
-        masks: Optional[bool] = False,
-    ) -> torch.Tensor:
+            predn["coco_masks"] = coco_masks
+        return predn
+
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         """
         Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
         Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
 
 
         Args:
         Args:
-            detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and
-                associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class].
-            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
-                Each row is of the format [x1, y1, x2, y2].
-            gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
-            pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
-                match the ground truth masks.
-            gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
-            overlap (bool, optional): Flag indicating if overlapping masks should be considered.
-            masks (bool, optional): Flag indicating if the batch contains mask data.
+            preds (Dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
+            batch (Dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.
 
 
         Returns:
         Returns:
-            (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
+            (Dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
 
         Notes:
             - If `masks` is True, the function computes IoU between predicted and ground truth masks.
             - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
 
         Examples:
-            >>> detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
-            >>> gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
-            >>> gt_cls = torch.tensor([1, 0])
-            >>> correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
+            >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
+            >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
+            >>> correct_preds = validator._process_batch(preds, batch)
         """
         """
-        if masks:
-            if overlap:
+        tp = super()._process_batch(preds, batch)
+        gt_cls, gt_masks = batch["cls"], batch["masks"]
+        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
+            tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
+        else:
+            pred_masks = preds["masks"]
+            if self.args.overlap_mask:
                 nl = len(gt_cls)
                 index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
                 gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
@@ -272,60 +189,32 @@ class SegmentationValidator(DetectionValidator):
                 gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
                 gt_masks = gt_masks.gt_(0.5)
             iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
-        else:  # boxes
-            iou = box_iou(gt_bboxes, detections[:, :4])
-
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
+            tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_m": tp_m})  # update tp with mask IoU
+        return tp
 
-    def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
-        """
-        Plot validation samples with bounding box labels and masks.
-
-        Args:
-            batch (Dict[str, Any]): Batch containing images and annotations.
-            ni (int): Batch index.
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            masks=batch["masks"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )
-
-    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
+    def plot_predictions(self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int) -> None:
         """
         """
         Plot batch predictions with masks and bounding boxes.
         Plot batch predictions with masks and bounding boxes.
 
 
         Args:
         Args:
             batch (Dict[str, Any]): Batch containing images and annotations.
             batch (Dict[str, Any]): Batch containing images and annotations.
-            preds (List[torch.Tensor]): List of predictions from the model.
+            preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
             ni (int): Batch index.
         """
-        plot_images(
-            batch["img"],
-            *output_to_target(preds[0], max_det=50),  # not set to self.args.max_det due to slow plotting speed
-            torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
-        self.plot_masks.clear()
-
-    def save_one_txt(
-        self, predn: torch.Tensor, pred_masks: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path
-    ) -> None:
+        for p in preds:
+            masks = p["masks"]
+            if masks.shape[0] > 50:
+                LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
+            p["masks"] = torch.as_tensor(masks[:50], dtype=torch.uint8).cpu()
+        super().plot_predictions(batch, preds, ni, max_det=50)  # plot bboxes
+
+    def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         """
         Save YOLO detections to a txt file in normalized coordinates in a specific format.
         Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
 
         Args:
         Args:
             predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
             predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
-            pred_masks (torch.Tensor): Predicted masks.
             save_conf (bool): Whether to save confidence scores.
             shape (Tuple[int, int]): Shape of the original image.
             file (Path): File path to save the detections.
@@ -336,18 +225,17 @@ class SegmentationValidator(DetectionValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
-            masks=pred_masks,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn: torch.Tensor, filename: str, pred_masks: torch.Tensor) -> None:
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
         """
         """
         Save one JSON result for COCO evaluation.
         Save one JSON result for COCO evaluation.
 
 
         Args:
         Args:
-            predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
+            predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
             filename (str): Image filename.
-            pred_masks (numpy.ndarray): Predicted masks.
 
         Examples:
              >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
@@ -362,18 +250,18 @@
 
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        pred_masks = np.transpose(pred_masks, (2, 0, 1))
+        pred_masks = np.transpose(predn["coco_masks"], (2, 0, 1))
         with ThreadPool(NUM_THREADS) as pool:
             rles = pool.map(single_encode, pred_masks)
-        for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
+        for i, (b, s, c) in enumerate(zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist())):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(c)],
                     "bbox": [round(x, 3) for x in b],
                     "bbox": [round(x, 3) for x in b],
-                    "score": round(p[4], 5),
+                    "score": round(s, 5),
                     "segmentation": rles[i],
                     "segmentation": rles[i],
                 }
                 }
             )
             )
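Segmentation `_process_batch` follows the same pattern as pose: the parent supplies box `tp`, and mask true positives (`tp_m`) are added from `mask_iou` over flattened masks. A standalone sketch with random binary masks (sizes are placeholders):

```python
import torch

from ultralytics.utils.metrics import mask_iou

gt_masks = (torch.rand(3, 160, 160) > 0.5).float()  # placeholder ground-truth masks
pred_masks = (torch.rand(5, 160, 160) > 0.5).float()  # placeholder predicted masks

# mask_iou takes flattened (N, H*W) masks and returns an (N, M) IoU matrix.
iou = mask_iou(gt_masks.view(3, -1), pred_masks.view(5, -1))
print(iou.shape)  # torch.Size([3, 5])
```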
@@ -457,7 +457,7 @@ def _log_plots(experiment, trainer) -> None:
         >>> _log_plots(experiment, trainer)
     """
     plot_filenames = None
-    if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment":
+    if isinstance(trainer.validator.metrics, SegmentMetrics):
         plot_filenames = [
             trainer.save_dir / f"{prefix}{plots}.png"
             for plots in EVALUATION_PLOT_NAMES
@@ -4,7 +4,7 @@
 import math
 import warnings
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import numpy as np
 import torch
@@ -316,28 +316,22 @@ class ConfusionMatrix(DataExportMixin):
     Attributes:
         task (str): The type of task, either 'detect' or 'classify'.
         matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
-        nc (int): The number of classes.
-        conf (float): The confidence threshold for detections.
-        iou_thres (float): The Intersection over Union threshold.
+        nc (int): The number of classes.
+        names (List[str]): The names of the classes, used as labels on the plot.
     """
     """
 
 
-    def __init__(self, nc: int, conf: float = 0.25, iou_thres: float = 0.45, names: tuple = (), task: str = "detect"):
+    def __init__(self, names: List[str] = [], task: str = "detect"):
         """
         """
         Initialize a ConfusionMatrix instance.
         Initialize a ConfusionMatrix instance.
 
 
         Args:
         Args:
-            nc (int): Number of classes.
-            conf (float, optional): Confidence threshold for detections.
-            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
-            names (tuple, optional): Names of classes, used as labels on the plot.
+            names (List[str], optional): Names of classes, used as labels on the plot.
             task (str, optional): Type of task, either 'detect' or 'classify'.
             task (str, optional): Type of task, either 'detect' or 'classify'.
         """
         """
         self.task = task
         self.task = task
-        self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
-        self.nc = nc  # number of classes
-        self.names = list(names)  # name of classes
-        self.conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
-        self.iou_thres = iou_thres
+        self.nc = len(names)  # number of classes
+        self.matrix = np.zeros((self.nc + 1, self.nc + 1)) if self.task == "detect" else np.zeros((self.nc, self.nc))
+        self.names = names  # name of classes
 
 
     def process_cls_preds(self, preds, targets):
     def process_cls_preds(self, preds, targets):
         """
         """
@@ -351,41 +345,45 @@ class ConfusionMatrix(DataExportMixin):
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
             self.matrix[p][t] += 1
 
-    def process_batch(self, detections, gt_bboxes, gt_cls):
+    def process_batch(
+        self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45
+    ) -> None:
         """
         Update confusion matrix for object detection task.
 
         Args:
-            detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
-                                      Each row should contain (x1, y1, x2, y2, conf, class)
-                                      or with an additional element `angle` when it's obb.
-            gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
-            gt_cls (Array[M]): The class labels.
+            detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.
+                                       Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
+                                       Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
+            batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4] | Array[M, 5]) and
+                'cls' (Array[M]) keys, where M is the number of ground truth objects.
+            conf (float, optional): Confidence threshold for detections.
+            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
         """
+        conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
+        gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
+        no_pred = len(detections["cls"]) == 0
         if gt_cls.shape[0] == 0:  # Check if labels is empty
-            if detections is not None:
-                detections = detections[detections[:, 4] > self.conf]
-                detection_classes = detections[:, 5].int().tolist()
+            if not no_pred:
+                detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
+                detection_classes = detections["cls"].int().tolist()
                 for dc in detection_classes:
                     self.matrix[dc, self.nc] += 1  # false positives
             return
-        if detections is None:
+        if no_pred:
             gt_classes = gt_cls.int().tolist()
             for gc in gt_classes:
                 self.matrix[self.nc, gc] += 1  # background FN
             return
 
-        detections = detections[detections[:, 4] > self.conf]
+        detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
         gt_classes = gt_cls.int().tolist()
-        detection_classes = detections[:, 5].int().tolist()
-        is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
-        iou = (
-            batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
-            if is_obb
-            else box_iou(gt_bboxes, detections[:, :4])
-        )
-
-        x = torch.where(iou > self.iou_thres)
+        detection_classes = detections["cls"].int().tolist()
+        bboxes = detections["bboxes"]
+        is_obb = bboxes.shape[1] == 5  # check if detections contains angle for OBB
+        iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)
+
+        x = torch.where(iou > iou_thres)
         if x[0].shape[0]:
             matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
             if x[0].shape[0] > 1:
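Callers now pass predictions and ground truth as dicts and supply the thresholds per call. A minimal sketch with dummy tensors; the key names match the docstring above, but the values are invented:

```python
import torch

from ultralytics.utils.metrics import ConfusionMatrix

cm = ConfusionMatrix(names=["person", "car"])
detections = {
    "bboxes": torch.tensor([[0.0, 0.0, 50.0, 50.0]]),  # xyxy; would be [N, 5] with angle for OBB
    "conf": torch.tensor([0.9]),
    "cls": torch.tensor([0.0]),
}
batch = {  # ground truth
    "bboxes": torch.tensor([[2.0, 2.0, 48.0, 49.0]]),
    "cls": torch.tensor([0.0]),
}
cm.process_batch(detections, batch, conf=0.25, iou_thres=0.45)
```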
@@ -949,53 +947,76 @@ class DetMetrics(SimpleClass, DataExportMixin):
     Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
 
     Attributes:
-        save_dir (Path): A path to the directory where the output plots will be saved.
-        plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
         names (Dict[int, str]): A dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'detect'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize a DetMetrics instance with a save directory, plot flag, and class names.
+        Initialize a DetMetrics instance with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
         self.names = names
         self.box = Metric()
         self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "detect"
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+        self.nt_per_class = None
+        self.nt_per_image = None
 
-    def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
+    def update_stats(self, stat: Dict[str, Any]) -> None:
+        """
+        Update statistics by appending new values to existing stat collections.
+
+        Args:
+            stat (Dict[str, Any]): Dictionary containing new statistical values to append.
+                Keys should match existing keys in self.stats.
+        """
+        for k in self.stats.keys():
+            self.stats[k].append(stat[k])
+
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process predicted results for object detection and update metrics.
 
         Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()}  # to numpy
+        if len(stats) == 0:
+            return stats
         results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
+            stats["tp"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
+            save_dir=save_dir,
             names=self.names,
             on_plot=on_plot,
         )[2:]
         self.box.nc = len(self.names)
         self.box.update(results)
+        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
+        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
+        return stats
+
+    def clear_stats(self):
+        """Clear the stored statistics."""
+        for v in self.stats.values():
+            v.clear()
 
     @property
     def keys(self) -> List[str]:
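The metrics object now accumulates per-batch statistics itself rather than receiving pre-concatenated arrays, and the plotting options move from `__init__` to `process()`. A rough sketch of the new lifecycle with dummy arrays; shapes are illustrative (`tp` holds one row per detection and one column per IoU threshold):

```python
import numpy as np

from ultralytics.utils.metrics import DetMetrics

dm = DetMetrics(names={0: "person", 1: "car"})
dm.update_stats(  # called once per validation batch
    {
        "tp": np.zeros((3, 10), dtype=bool),
        "conf": np.array([0.9, 0.8, 0.7]),
        "pred_cls": np.array([0.0, 1.0, 0.0]),
        "target_cls": np.array([0.0, 1.0]),
        "target_img": np.array([0.0, 1.0]),
    }
)
stats = dm.process(plot=False)  # concatenates, runs ap_per_class, returns the arrays
dm.clear_stats()  # drop the per-batch lists once metrics are computed
```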
@@ -1077,92 +1098,65 @@ class DetMetrics(SimpleClass, DataExportMixin):
         ]
 
 
-class SegmentMetrics(SimpleClass, DataExportMixin):
+class SegmentMetrics(DetMetrics):
     """
     Calculate and aggregate detection and segmentation metrics over a given set of classes.
 
     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and segmentation plots.
         names (Dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'segment'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
+        Initialize a SegmentMetrics instance with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        DetMetrics.__init__(self, names)
         self.seg = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "segment"
+        self.stats["tp_m"] = []  # add additional stats for masks
 
-    def process(
-        self,
-        tp: np.ndarray,
-        tp_m: np.ndarray,
-        conf: np.ndarray,
-        pred_cls: np.ndarray,
-        target_cls: np.ndarray,
-        on_plot=None,
-    ):
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process the detection and segmentation metrics over the given set of predictions.
 
         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_m (np.ndarray): True positive array for masks.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot)  # process box stats
         results_mask = ap_per_class(
-            tp_m,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_m"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Mask",
         )[2:]
         self.seg.nc = len(self.names)
         self.seg.update(results_mask)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats
 
     @property
     def keys(self) -> List[str]:
         """Return a list of keys for accessing metrics."""
-        return [
-            "metrics/precision(B)",
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+        return DetMetrics.keys.fget(self) + [
             "metrics/precision(M)",
             "metrics/recall(M)",
             "metrics/mAP50(M)",
@@ -1171,40 +1165,26 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
 
     def mean_results(self) -> List[float]:
         """Return the mean metrics for bounding box and segmentation results."""
-        return self.box.mean_results() + self.seg.mean_results()
+        return DetMetrics.mean_results(self) + self.seg.mean_results()
 
     def class_result(self, i: int) -> List[float]:
         """Return classification results for a specified class index."""
-        return self.box.class_result(i) + self.seg.class_result(i)
+        return DetMetrics.class_result(self, i) + self.seg.class_result(i)
 
     @property
     def maps(self) -> np.ndarray:
         """Return mAP scores for object detection and semantic segmentation models."""
-        return self.box.maps + self.seg.maps
+        return DetMetrics.maps.fget(self) + self.seg.maps
 
     @property
     def fitness(self) -> float:
         """Return the fitness score for both segmentation and bounding box models."""
-        return self.seg.fitness() + self.box.fitness()
-
-    @property
-    def ap_class_index(self) -> List:
-        """Return the class indices (boxes and masks have the same ap_class_index)."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self) -> Dict[str, float]:
-        """Return results of object detection model for evaluation."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
+        return self.seg.fitness() + DetMetrics.fitness.fget(self)
 
     @property
     def curves(self) -> List[str]:
         """Return a list of curves for accessing specific metrics curves."""
-        return [
-            "Precision-Recall(B)",
-            "F1-Confidence(B)",
-            "Precision-Confidence(B)",
-            "Recall-Confidence(B)",
+        return DetMetrics.curves.fget(self) + [
             "Precision-Recall(M)",
             "F1-Confidence(M)",
             "Precision-Confidence(M)",
@@ -1214,7 +1194,7 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
     @property
     def curves_results(self) -> List[List]:
         """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.seg.curves_results
+        return DetMetrics.curves_results.fget(self) + self.seg.curves_results
 
     def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
         """
@@ -1234,43 +1214,34 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
             >>> print(seg_summary)
         """
         scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
             "mask-map": round(self.seg.map, decimals),
             "mask-map50": round(self.seg.map50, decimals),
             "mask-map75": round(self.seg.map75, decimals),
         }
         per_class = {
-            "box-p": self.box.p,
-            "box-r": self.box.r,
-            "box-f1": self.box.f1,
             "mask-p": self.seg.p,
             "mask-r": self.seg.r,
             "mask-f1": self.seg.f1,
         }
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
+        return summary
 
 
-class PoseMetrics(SegmentMetrics):
+class PoseMetrics(DetMetrics):
     """
     Calculate and aggregate detection and pose metrics over a given set of classes.
 
     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and pose plots.
         names (Dict[int, str]): Dictionary of class names.
         pose (Metric): An instance of the Metric class to calculate pose metrics.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'pose'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
 
     Methods:
-        process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.
+        process(save_dir, plot, on_plot): Process metrics over the given set of predictions.
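The `summary()` override follows the same delegation pattern: it takes the per-class box rows from `DetMetrics.summary()` and updates each row in place with the mask (or pose) columns. One merged segmentation row might look like this sketch; all numbers are invented:

```python
# One merged per-class row after SegmentMetrics.summary() (values illustrative only)
row = {
    "class_name": "person",
    "box-p": 0.81, "box-r": 0.72, "box-f1": 0.76,
    "box-map": 0.61, "box-map50": 0.83, "box-map75": 0.66,
    "mask-p": 0.78, "mask-r": 0.66, "mask-f1": 0.71,
    "mask-map": 0.55, "mask-map50": 0.75, "mask-map75": 0.60,
}
```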
@@ -1282,79 +1253,50 @@ class PoseMetrics(SegmentMetrics):
         results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize the PoseMetrics class with directory path, class names, and plotting options.
+        Initialize the PoseMetrics class with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        super().__init__(save_dir, plot, names)
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        super().__init__(names)
         self.pose = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "pose"
+        self.stats["tp_p"] = []  # add additional stats for pose
 
-    def process(
-        self,
-        tp: np.ndarray,
-        tp_p: np.ndarray,
-        conf: np.ndarray,
-        pred_cls: np.ndarray,
-        target_cls: np.ndarray,
-        on_plot=None,
-    ):
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process the detection and pose metrics over the given set of predictions.
 
         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_p (np.ndarray): True positive array for keypoints.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
             on_plot (callable, optional): Function to call after plots are generated.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot)  # process box stats
         results_pose = ap_per_class(
-            tp_p,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_p"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Pose",
         )[2:]
         self.pose.nc = len(self.names)
         self.pose.update(results_pose)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats
 
     @property
     def keys(self) -> List[str]:
         """Return list of evaluation metric keys."""
-        return [
-            "metrics/precision(B)",
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+        return DetMetrics.keys.fget(self) + [
             "metrics/precision(P)",
             "metrics/recall(P)",
             "metrics/mAP50(P)",
@@ -1363,26 +1305,26 @@ class PoseMetrics(SegmentMetrics):
 
     def mean_results(self) -> List[float]:
         """Return the mean results of box and pose."""
-        return self.box.mean_results() + self.pose.mean_results()
+        return DetMetrics.mean_results(self) + self.pose.mean_results()
 
     def class_result(self, i: int) -> List[float]:
         """Return the class-wise detection results for a specific class i."""
-        return self.box.class_result(i) + self.pose.class_result(i)
+        return DetMetrics.class_result(self, i) + self.pose.class_result(i)
 
     @property
     def maps(self) -> np.ndarray:
         """Return the mean average precision (mAP) per class for both box and pose detections."""
-        return self.box.maps + self.pose.maps
+        return DetMetrics.maps.fget(self) + self.pose.maps
 
     @property
     def fitness(self) -> float:
         """Return combined fitness score for pose and box detection."""
-        return self.pose.fitness() + self.box.fitness()
+        return self.pose.fitness() + DetMetrics.fitness.fget(self)
 
     @property
     def curves(self) -> List[str]:
         """Return a list of curves for accessing specific metrics curves."""
-        return [
+        return DetMetrics.curves.fget(self) + [
             "Precision-Recall(B)",
             "F1-Confidence(B)",
             "Precision-Confidence(B)",
@@ -1396,7 +1338,7 @@ class PoseMetrics(SegmentMetrics):
     @property
     def curves_results(self) -> List[List]:
         """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.pose.curves_results
+        return DetMetrics.curves_results.fget(self) + self.pose.curves_results
 
     def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
         """
@@ -1416,29 +1358,19 @@ class PoseMetrics(SegmentMetrics):
             >>> print(pose_summary)
         """
         scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
             "pose-map": round(self.pose.map, decimals),
             "pose-map50": round(self.pose.map50, decimals),
             "pose-map75": round(self.pose.map75, decimals),
         }
         per_class = {
-            "box-p": self.box.p,
-            "box-r": self.box.r,
-            "box-f1": self.box.f1,
            "pose-p": self.pose.p,
            "pose-r": self.pose.r,
            "pose-f1": self.pose.f1,
         }
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
+        return summary
 
 
 class ClassifyMetrics(SimpleClass, DataExportMixin):
@@ -1516,133 +1448,30 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
         return [{"classify-top1": round(self.top1, decimals), "classify-top5": round(self.top5, decimals)}]
 
 
-class OBBMetrics(SimpleClass, DataExportMixin):
+class OBBMetrics(DetMetrics):
     """
     Metrics for evaluating oriented bounding box (OBB) detection.
 
     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection plots.
         names (Dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'obb'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
 
     References:
         https://arxiv.org/pdf/2106.06072.pdf
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize an OBBMetrics instance with directory, plotting, and class names.
+        Initialize an OBBMetrics instance with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+        DetMetrics.__init__(self, names)
+        # TODO: probably remove task as well
         self.task = "obb"
-
-    def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
-        """
-        Process predicted results for object detection and update metrics.
-
-        Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
-        """
-        results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            on_plot=on_plot,
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results)
-
-    @property
-    def keys(self) -> List[str]:
-        """Return a list of keys for accessing specific metrics."""
-        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
-
-    def mean_results(self) -> List[float]:
-        """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
-        return self.box.mean_results()
-
-    def class_result(self, i: int) -> Tuple[float, float, float, float]:
-        """Return the result of evaluating the performance of an object detection model on a specific class."""
-        return self.box.class_result(i)
-
-    @property
-    def maps(self) -> np.ndarray:
-        """Return mean Average Precision (mAP) scores per class."""
-        return self.box.maps
-
-    @property
-    def fitness(self) -> float:
-        """Return the fitness of box object."""
-        return self.box.fitness()
-
-    @property
-    def ap_class_index(self) -> List:
-        """Return the average precision index per class."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self) -> Dict[str, float]:
-        """Return dictionary of computed performance metrics and statistics."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
-
-    @property
-    def curves(self) -> List:
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    @property
-    def curves_results(self) -> List:
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
-        """
-        Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
-        scalar metrics (mAP, mAP50, mAP75) along with precision, recall, and F1-score for each class.
-
-        Args:
-            normalize (bool): For OBB metrics, everything is normalized by default [0-1].
-            decimals (int): Number of decimal places to round the metrics values to.
-
-        Returns:
-            (List[Dict[str, Union[str, float]]]): A list of dictionaries, each representing one class with detection metrics.
-
-        Examples:
-            >>> results = model.val(data="dota8.yaml")
-            >>> detection_summary = results.summary(decimals=4)
-            >>> print(detection_summary)
-        """
-        scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
-        }
-        per_class = {"box-p": self.box.p, "box-r": self.box.r, "box-f1": self.box.f1}
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
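With the shared behaviour hoisted into `DetMetrics`, `OBBMetrics` shrinks to a constructor that sets its task name; everything else (keys, `mean_results`, `summary`, export) is inherited unchanged. A quick sketch:

```python
from ultralytics.utils.metrics import DetMetrics, OBBMetrics

obb = OBBMetrics(names={0: "plane", 1: "ship"})
assert isinstance(obb, DetMetrics)
print(obb.task)  # 'obb'
print(obb.keys)  # the same four '(B)' keys as DetMetrics
```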
@@ -255,7 +255,7 @@ def non_max_suppression(
 
     bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
     nc = nc or (prediction.shape[1] - 4)  # number of classes
-    nm = prediction.shape[1] - nc - 4  # number of masks
+    extra = prediction.shape[1] - nc - 4  # number of extra channels (e.g. mask coefficients or OBB angle)
     mi = 4 + nc  # mask start index
     xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
     xinds = torch.stack([torch.arange(len(i), device=prediction.device) for i in xc])[..., None]  # to track idxs
@@ -273,7 +273,7 @@
             prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy
 
     t = time.time()
-    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
+    output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs
     keepi = [torch.zeros((0, 1), device=prediction.device)] * bs  # to store the kept idxs
     for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
         # Apply constraints
@@ -284,7 +284,7 @@
         # Cat apriori labels if autolabelling
         if labels and len(labels[xi]) and not rotated:
             lb = labels[xi]
-            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
+            v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
             v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
             v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
             x = torch.cat((x, v), 0)
@@ -294,7 +294,7 @@
             continue
 
         # Detections matrix nx6 (xyxy, conf, cls)
-        box, cls, mask = x.split((4, nc, nm), 1)
+        box, cls, mask = x.split((4, nc, extra), 1)
 
         if multi_label:
             i, j = torch.where(cls > conf_thres)
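The `nm` to `extra` rename is purely cosmetic: the channels trailing the class scores are mask coefficients for segmentation but the rotation angle for OBB, so the neutral name fits both. Each kept row remains `(x1, y1, x2, y2, conf, cls, *extra)`; a small sketch of splitting such rows back apart, with invented sizes:

```python
import torch

nc, extra = 80, 32  # e.g. 32 mask coefficients from a segmentation head
kept = torch.zeros((5, 6 + extra))  # 5 detections surviving NMS
boxes, conf, cls, extra_info = kept.split((4, 1, 1, extra), dim=1)
print(boxes.shape, conf.shape, cls.shape, extra_info.shape)
# torch.Size([5, 4]) torch.Size([5, 1]) torch.Size([5, 1]) torch.Size([5, 32])
```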
@@ -3,7 +3,7 @@
 import math
 import warnings
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import cv2
 import numpy as np
@@ -678,13 +678,8 @@ def save_one_box(
 
 @threaded
 def plot_images(
-    images: Union[torch.Tensor, np.ndarray],
-    batch_idx: Union[torch.Tensor, np.ndarray],
-    cls: Union[torch.Tensor, np.ndarray],
-    bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
-    confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
-    masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
-    kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
+    labels: Dict[str, Any],
+    images: Union[torch.Tensor, np.ndarray] = np.zeros((0, 3, 640, 640), dtype=np.float32),
     paths: Optional[List[str]] = None,
     fname: str = "images.jpg",
     names: Optional[Dict[int, str]] = None,
@@ -698,21 +693,16 @@
     Plot image grid with labels, bounding boxes, masks, and keypoints.
 
     Args:
-        images: Batch of images to plot. Shape: (batch_size, channels, height, width).
-        batch_idx: Batch indices for each detection. Shape: (num_detections,).
-        cls: Class labels for each detection. Shape: (num_detections,).
-        bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
-        confs: Confidence scores for each detection. Shape: (num_detections,).
-        masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
-        kpts: Keypoints for each detection. Shape: (num_detections, 51).
-        paths: List of file paths for each image in the batch.
-        fname: Output filename for the plotted image grid.
-        names: Dictionary mapping class indices to class names.
-        on_plot: Optional callback function to be called after saving the plot.
-        max_size: Maximum size of the output image grid.
-        max_subplots: Maximum number of subplots in the image grid.
-        save: Whether to save the plotted image grid to a file.
-        conf_thres: Confidence threshold for displaying detections.
+        labels (Dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.
+        images (Union[torch.Tensor, np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
+        paths (Optional[List[str]]): List of file paths for each image in the batch.
+        fname (str): Output filename for the plotted image grid.
+        names (Optional[Dict[int, str]]): Dictionary mapping class indices to class names.
+        on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.
+        max_size (int): Maximum size of the output image grid.
+        max_subplots (int): Maximum number of subplots in the image grid.
+        save (bool): Whether to save the plotted image grid to a file.
+        conf_thres (float): Confidence threshold for displaying detections.
 
     Returns:
         (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
@@ -721,18 +711,24 @@
         This function supports both tensor and numpy array inputs. It will automatically
         convert tensor inputs to numpy arrays for processing.
     """
-    if isinstance(images, torch.Tensor):
+    for k in {"cls", "bboxes", "conf", "masks", "keypoints", "batch_idx", "images"}:
+        if k not in labels:
+            continue
+        if k == "cls" and labels[k].ndim == 2:
+            labels[k] = labels[k].squeeze(1)  # squeeze if shape is (n, 1)
+        if isinstance(labels[k], torch.Tensor):
+            labels[k] = labels[k].cpu().numpy()
+
+    cls = labels.get("cls", np.zeros(0, dtype=np.int64))
+    batch_idx = labels.get("batch_idx", np.zeros(cls.shape, dtype=np.int64))
+    bboxes = labels.get("bboxes", np.zeros(0, dtype=np.float32))
+    confs = labels.get("conf", None)
+    masks = labels.get("masks", np.zeros(0, dtype=np.uint8))
+    kpts = labels.get("keypoints", np.zeros(0, dtype=np.float32))
+    images = labels.get("img", images)  # default to input images
+
+    if len(images) and isinstance(images, torch.Tensor):
         images = images.cpu().float().numpy()
-    if isinstance(cls, torch.Tensor):
-        cls = cls.cpu().numpy()
-    if isinstance(bboxes, torch.Tensor):
-        bboxes = bboxes.cpu().numpy()
-    if isinstance(masks, torch.Tensor):
-        masks = masks.cpu().numpy().astype(int)
-    if isinstance(kpts, torch.Tensor):
-        kpts = kpts.cpu().numpy()
-    if isinstance(batch_idx, torch.Tensor):
-        batch_idx = batch_idx.cpu().numpy()
     if images.shape[1] > 3:
         images = images[:, :3]  # crop multispectral images to first 3 channels
 
@@ -781,6 +777,7 @@
                 boxes[..., 0] += x
                 boxes[..., 1] += y
                 is_obb = boxes.shape[-1] == 5  # xywhr
+                # TODO: this transformation might be unnecessary
                 boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
                 for j, box in enumerate(boxes.astype(np.int64).tolist()):
                     c = classes[j]
@@ -1004,28 +1001,6 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
     _save_one_file(csv_file.with_name("tune_fitness.png"))
 
 
-def output_to_target(output, max_det: int = 300):
-    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
-    targets = []
-    for i, o in enumerate(output):
-        box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
-        j = torch.full((conf.shape[0], 1), i)
-        targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
-    targets = torch.cat(targets, 0).numpy()
-    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
-
-
-def output_to_rotated_target(output, max_det: int = 300):
-    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
-    targets = []
-    for i, o in enumerate(output):
-        box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
-        j = torch.full((conf.shape[0], 1), i)
-        targets.append(torch.cat((j, cls, box, angle, conf), 1))
-    targets = torch.cat(targets, 0).numpy()
-    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
-
-
 def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
     """
     Visualize feature maps of a given model module during inference.
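Call sites now hand `plot_images` the whole batch dict, which also removes the need for the deleted `output_to_target` helpers: tensors are converted to numpy internally, an `(n, 1)` class column is squeezed automatically, and an `'img'` entry in the dict takes precedence over the `images` argument. A minimal sketch with random data (all values invented):

```python
import torch

from ultralytics.utils.plotting import plot_images

labels = {
    "img": torch.rand(2, 3, 64, 64),  # 2-image batch
    "batch_idx": torch.tensor([0.0, 1.0]),  # which image each box belongs to
    "cls": torch.tensor([[0.0], [1.0]]),  # (n, 1) is squeezed to (n,)
    "bboxes": torch.tensor([[32.0, 32.0, 16.0, 16.0], [20.0, 20.0, 10.0, 10.0]]),  # xywh
    "conf": torch.tensor([0.9, 0.8]),
}
plot_images(labels, names={0: "person", 1: "car"}, fname="batch0.jpg")
```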