@@ -4,7 +4,7 @@
 import math
 import warnings
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import numpy as np
 import torch
@@ -316,28 +316,22 @@ class ConfusionMatrix(DataExportMixin):
     Attributes:
         task (str): The type of task, either 'detect' or 'classify'.
         matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
         nc (int): The number of classes.
-        conf (float): The confidence threshold for detections.
-        iou_thres (float): The Intersection over Union threshold.
+        names (List[str]): The names of the classes, used as labels on the plot.
     """
 
-    def __init__(self, nc: int, conf: float = 0.25, iou_thres: float = 0.45, names: tuple = (), task: str = "detect"):
+    def __init__(self, names: List[str] = [], task: str = "detect"):
        """
         Initialize a ConfusionMatrix instance.
 
         Args:
-            nc (int): Number of classes.
-            conf (float, optional): Confidence threshold for detections.
-            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
-            names (tuple, optional): Names of classes, used as labels on the plot.
+            names (List[str], optional): Names of classes, used as labels on the plot.
             task (str, optional): Type of task, either 'detect' or 'classify'.
         """
         self.task = task
-        self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
-        self.nc = nc  # number of classes
-        self.names = list(names)  # name of classes
-        self.conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
-        self.iou_thres = iou_thres
+        self.nc = len(names)  # number of classes
+        self.matrix = np.zeros((self.nc + 1, self.nc + 1)) if self.task == "detect" else np.zeros((self.nc, self.nc))
+        self.names = names  # name of classes
 
     def process_cls_preds(self, preds, targets):
         """
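With this change the matrix size is derived from `names`, so callers no longer pass `nc`, and the `conf`/`iou_thres` thresholds move from the constructor to `process_batch` (next hunk). A minimal sketch of the new construction, using made-up class names:

    cm = ConfusionMatrix(names=["person", "car", "dog"], task="detect")
    assert cm.nc == 3 and cm.matrix.shape == (4, 4)  # one extra row/column for background in 'detect'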
@@ -351,41 +345,45 @@ class ConfusionMatrix(DataExportMixin):
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
             self.matrix[p][t] += 1
 
-    def process_batch(self, detections, gt_bboxes, gt_cls):
+    def process_batch(
+        self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45
+    ) -> None:
         """
         Update confusion matrix for object detection task.
 
         Args:
-            detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
-                Each row should contain (x1, y1, x2, y2, conf, class)
-                or with an additional element `angle` when it's obb.
-            gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
-            gt_cls (Array[M]): The class labels.
+            detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their
+                associated information. Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
+                Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
+            batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4] |
+                Array[M, 5]) and 'cls' (Array[M]) keys, where M is the number of ground truth objects.
+            conf (float, optional): Confidence threshold for detections.
+            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
         """
+        conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
+        gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
+        no_pred = len(detections["cls"]) == 0
         if gt_cls.shape[0] == 0:  # Check if labels is empty
-            if detections is not None:
-                detections = detections[detections[:, 4] > self.conf]
-                detection_classes = detections[:, 5].int().tolist()
+            if not no_pred:
+                detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
+                detection_classes = detections["cls"].int().tolist()
                 for dc in detection_classes:
                     self.matrix[dc, self.nc] += 1  # false positives
             return
-        if detections is None:
+        if no_pred:
             gt_classes = gt_cls.int().tolist()
             for gc in gt_classes:
                 self.matrix[self.nc, gc] += 1  # background FN
             return
 
-        detections = detections[detections[:, 4] > self.conf]
+        detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
         gt_classes = gt_cls.int().tolist()
-        detection_classes = detections[:, 5].int().tolist()
-        is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
-        iou = (
-            batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
-            if is_obb
-            else box_iou(gt_bboxes, detections[:, :4])
-        )
-
-        x = torch.where(iou > self.iou_thres)
+        detection_classes = detections["cls"].int().tolist()
+        bboxes = detections["bboxes"]
+        is_obb = bboxes.shape[1] == 5  # check whether detections contain an angle for OBB
+        iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)
+
+        x = torch.where(iou > iou_thres)
         if x[0].shape[0]:
             matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
             if x[0].shape[0] > 1:
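`process_batch` now takes predictions and ground truth in the dictionary layout validators already carry around ('bboxes'/'conf'/'cls' for predictions, 'bboxes'/'cls' for labels), with thresholds supplied per call. A rough usage sketch under that assumption, with toy tensors:

    import torch

    cm = ConfusionMatrix(names=["person"], task="detect")
    detections = {
        "bboxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),  # xyxy, Array[N, 4]
        "conf": torch.tensor([0.9]),
        "cls": torch.tensor([0.0]),
    }
    batch = {"bboxes": torch.tensor([[12.0, 11.0, 49.0, 52.0]]), "cls": torch.tensor([0.0])}  # ground truth
    cm.process_batch(detections, batch, conf=0.25, iou_thres=0.45)

Note that OBB input is detected purely from `bboxes.shape[1] == 5`, so predictions and ground truth must arrive in the same box format.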
@@ -949,53 +947,76 @@ class DetMetrics(SimpleClass, DataExportMixin):
     Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
 
     Attributes:
-        save_dir (Path): A path to the directory where the output plots will be saved.
-        plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
         names (Dict[int, str]): A dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'detect'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize a DetMetrics instance with a save directory, plot flag, and class names.
+        Initialize a DetMetrics instance with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
         self.names = names
         self.box = Metric()
         self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "detect"
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+        self.nt_per_class = None
+        self.nt_per_image = None
 
-    def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
+    def update_stats(self, stat: Dict[str, Any]) -> None:
+        """
+        Update statistics by appending new values to existing stat collections.
+
+        Args:
+            stat (Dict[str, Any]): Dictionary containing new statistical values to append.
+                Keys should match existing keys in self.stats.
+        """
+        for k in self.stats.keys():
+            self.stats[k].append(stat[k])
+
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process predicted results for object detection and update metrics.
 
         Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path, optional): Directory to save plots. Defaults to Path(".").
+            plot (bool, optional): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()}  # to numpy
+        if len(stats) == 0:
+            return stats
         results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
+            stats["tp"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
+            save_dir=save_dir,
             names=self.names,
             on_plot=on_plot,
         )[2:]
         self.box.nc = len(self.names)
         self.box.update(results)
+        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
+        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
+        return stats
+
+    def clear_stats(self):
+        """Clear the stored statistics."""
+        for v in self.stats.values():
+            v.clear()
 
     @property
     def keys(self) -> List[str]:
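`DetMetrics` now owns the accumulation loop that validators previously implemented themselves: `update_stats` appends one dict of per-batch arrays, `process` concatenates them, runs `ap_per_class`, and fills `nt_per_class`/`nt_per_image`, and `clear_stats` empties the buffers for the next run. A hedged sketch of the intended cycle, with fabricated arrays standing in for real validator output:

    import numpy as np
    from pathlib import Path

    dm = DetMetrics(names={0: "person", 1: "car"})
    for _ in range(3):  # one update_stats call per validation batch
        dm.update_stats(
            dict(
                tp=np.zeros((8, 10), dtype=bool),  # 8 predictions x 10 IoU thresholds
                conf=np.random.rand(8),
                pred_cls=np.random.randint(0, 2, 8),
                target_cls=np.random.randint(0, 2, 5),
                target_img=np.unique(np.random.randint(0, 2, 5)),
            )
        )
    stats = dm.process(save_dir=Path("."), plot=False)  # concatenated numpy arrays, keyed like self.stats
    dm.clear_stats()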
@@ -1077,92 +1098,65 @@ class DetMetrics(SimpleClass, DataExportMixin):
         ]
 
 
-class SegmentMetrics(SimpleClass, DataExportMixin):
+class SegmentMetrics(DetMetrics):
     """
     Calculate and aggregate detection and segmentation metrics over a given set of classes.
 
     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and segmentation plots.
         names (Dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'segment'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
+        Initialize a SegmentMetrics instance with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        DetMetrics.__init__(self, names)
         self.seg = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "segment"
+        self.stats["tp_m"] = []  # add additional stats for masks
 
-    def process(
-        self,
-        tp: np.ndarray,
-        tp_m: np.ndarray,
-        conf: np.ndarray,
-        pred_cls: np.ndarray,
-        target_cls: np.ndarray,
-        on_plot=None,
-    ):
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process the detection and segmentation metrics over the given set of predictions.
 
         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_m (np.ndarray): True positive array for masks.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path, optional): Directory to save plots. Defaults to Path(".").
+            plot (bool, optional): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot)  # process box stats
         results_mask = ap_per_class(
-            tp_m,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_m"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Mask",
         )[2:]
         self.seg.nc = len(self.names)
         self.seg.update(results_mask)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats
 
     @property
     def keys(self) -> List[str]:
         """Return a list of keys for accessing metrics."""
-        return [
-            "metrics/precision(B)",
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+        return DetMetrics.keys.fget(self) + [
             "metrics/precision(M)",
             "metrics/recall(M)",
             "metrics/mAP50(M)",
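Because `SegmentMetrics` now extends `DetMetrics`, the only segmentation-specific state is the extra `seg` Metric and the `tp_m` entry in `stats`, so every per-batch dict passed to the inherited `update_stats` needs a mask true-positive array alongside the box one. A quick check of the resulting buffer layout:

    sm = SegmentMetrics(names={0: "person"})
    sorted(sm.stats)  # ['conf', 'pred_cls', 'target_cls', 'target_img', 'tp', 'tp_m']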
@@ -1171,40 +1165,26 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
 
     def mean_results(self) -> List[float]:
         """Return the mean metrics for bounding box and segmentation results."""
-        return self.box.mean_results() + self.seg.mean_results()
+        return DetMetrics.mean_results(self) + self.seg.mean_results()
 
     def class_result(self, i: int) -> List[float]:
         """Return classification results for a specified class index."""
-        return self.box.class_result(i) + self.seg.class_result(i)
+        return DetMetrics.class_result(self, i) + self.seg.class_result(i)
 
     @property
     def maps(self) -> np.ndarray:
         """Return mAP scores for object detection and semantic segmentation models."""
-        return self.box.maps + self.seg.maps
+        return DetMetrics.maps.fget(self) + self.seg.maps
 
     @property
     def fitness(self) -> float:
         """Return the fitness score for both segmentation and bounding box models."""
-        return self.seg.fitness() + self.box.fitness()
-
-    @property
-    def ap_class_index(self) -> List:
-        """Return the class indices (boxes and masks have the same ap_class_index)."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self) -> Dict[str, float]:
-        """Return results of object detection model for evaluation."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
+        return self.seg.fitness() + DetMetrics.fitness.fget(self)
 
     @property
     def curves(self) -> List[str]:
         """Return a list of curves for accessing specific metrics curves."""
-        return [
-            "Precision-Recall(B)",
-            "F1-Confidence(B)",
-            "Precision-Confidence(B)",
-            "Recall-Confidence(B)",
+        return DetMetrics.curves.fget(self) + [
             "Precision-Recall(M)",
             "F1-Confidence(M)",
             "Precision-Confidence(M)",
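The `DetMetrics.<prop>.fget(self)` calls above are how a subclass reaches a parent property it has overridden: `fget` is the property object's getter function, invoked explicitly on the current instance. A generic illustration of the pattern (names here are invented):

    class Base:
        @property
        def vals(self):
            return [1, 2]

    class Child(Base):
        @property
        def vals(self):
            return Base.vals.fget(self) + [3]  # extend, rather than replace, the parent property

    assert Child().vals == [1, 2, 3]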
@@ -1214,7 +1194,7 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
     @property
     def curves_results(self) -> List[List]:
         """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.seg.curves_results
+        return DetMetrics.curves_results.fget(self) + self.seg.curves_results
 
     def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
         """
@@ -1234,43 +1214,34 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
         >>> print(seg_summary)
         """
         scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
             "mask-map": round(self.seg.map, decimals),
             "mask-map50": round(self.seg.map50, decimals),
             "mask-map75": round(self.seg.map75, decimals),
         }
         per_class = {
-            "box-p": self.box.p,
-            "box-r": self.box.r,
-            "box-f1": self.box.f1,
             "mask-p": self.seg.p,
             "mask-r": self.seg.r,
             "mask-f1": self.seg.f1,
         }
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
+        return summary
 
 
-class PoseMetrics(SegmentMetrics):
+class PoseMetrics(DetMetrics):
     """
     Calculate and aggregate detection and pose metrics over a given set of classes.
 
     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and pose plots.
         names (Dict[int, str]): Dictionary of class names.
         pose (Metric): An instance of the Metric class to calculate pose metrics.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'pose'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
 
     Methods:
-        process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.
+        process(save_dir, plot, on_plot): Process metrics over the given set of predictions.
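The rewritten `summary` builds on the parent rather than duplicating the row construction: `DetMetrics.summary` yields one dict per class with the box columns, and the loop then merges the mask columns and scalars into each row in place. Assuming metrics have been processed, a row might be inspected like this:

    row = sm.summary(decimals=4)[0]
    # box-p/box-r/box-f1 and box-map* come from DetMetrics.summary;
    # mask-p/mask-r/mask-f1 and mask-map* are merged in by SegmentMetrics.summary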
@@ -1282,79 +1253,50 @@ class PoseMetrics(SegmentMetrics):
-        results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
+        results_dict: Return the dictionary containing all the detection and pose metrics and fitness score.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize the PoseMetrics class with directory path, class names, and plotting options.
+        Initialize the PoseMetrics class with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        super().__init__(save_dir, plot, names)
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        super().__init__(names)
         self.pose = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "pose"
+        self.stats["tp_p"] = []  # add additional stats for pose
 
-    def process(
-        self,
-        tp: np.ndarray,
-        tp_p: np.ndarray,
-        conf: np.ndarray,
-        pred_cls: np.ndarray,
-        target_cls: np.ndarray,
-        on_plot=None,
-    ):
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process the detection and pose metrics over the given set of predictions.
 
         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_p (np.ndarray): True positive array for keypoints.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
+            save_dir (Path, optional): Directory to save plots. Defaults to Path(".").
+            plot (bool, optional): Whether to plot precision-recall curves. Defaults to False.
             on_plot (callable, optional): Function to call after plots are generated.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot)  # process box stats
         results_pose = ap_per_class(
-            tp_p,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_p"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Pose",
         )[2:]
         self.pose.nc = len(self.names)
         self.pose.update(results_pose)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats
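`PoseMetrics` follows the same template as `SegmentMetrics`: inherit box handling from `DetMetrics`, add a `pose` Metric, and register a `tp_p` buffer for keypoint true positives:

    pm = PoseMetrics(names={0: "person"})
    sorted(pm.stats)  # ['conf', 'pred_cls', 'target_cls', 'target_img', 'tp', 'tp_p']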
@@ -1363,26 +1305,26 @@ class PoseMetrics(SegmentMetrics):
 
     def mean_results(self) -> List[float]:
         """Return the mean results of box and pose."""
-        return self.box.mean_results() + self.pose.mean_results()
+        return DetMetrics.mean_results(self) + self.pose.mean_results()
 
     def class_result(self, i: int) -> List[float]:
         """Return the class-wise detection results for a specific class i."""
-        return self.box.class_result(i) + self.pose.class_result(i)
+        return DetMetrics.class_result(self, i) + self.pose.class_result(i)
 
     @property
     def maps(self) -> np.ndarray:
         """Return the mean average precision (mAP) per class for both box and pose detections."""
-        return self.box.maps + self.pose.maps
+        return DetMetrics.maps.fget(self) + self.pose.maps
 
     @property
     def fitness(self) -> float:
         """Return combined fitness score for pose and box detection."""
-        return self.pose.fitness() + self.box.fitness()
+        return self.pose.fitness() + DetMetrics.fitness.fget(self)
 
     @property
     def curves(self) -> List[str]:
         """Return a list of curves for accessing specific metrics curves."""
-        return [
+        return DetMetrics.curves.fget(self) + [
             "Precision-Recall(B)",
             "F1-Confidence(B)",
             "Precision-Confidence(B)",
@@ -1396,7 +1338,7 @@ class PoseMetrics(SegmentMetrics):
     @property
     def curves_results(self) -> List[List]:
         """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.pose.curves_results
+        return DetMetrics.curves_results.fget(self) + self.pose.curves_results
 
     def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
         """
@@ -1416,29 +1358,19 @@ class PoseMetrics(SegmentMetrics):
         >>> print(pose_summary)
         """
         scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
             "pose-map": round(self.pose.map, decimals),
             "pose-map50": round(self.pose.map50, decimals),
             "pose-map75": round(self.pose.map75, decimals),
         }
         per_class = {
-            "box-p": self.box.p,
-            "box-r": self.box.r,
-            "box-f1": self.box.f1,
             "pose-p": self.pose.p,
             "pose-r": self.pose.r,
             "pose-f1": self.pose.f1,
         }
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
+        return summary
 
 
 class ClassifyMetrics(SimpleClass, DataExportMixin):
@@ -1516,133 +1448,30 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
         return [{"classify-top1": round(self.top1, decimals), "classify-top5": round(self.top5, decimals)}]
 
 
-class OBBMetrics(SimpleClass, DataExportMixin):
+class OBBMetrics(DetMetrics):
     """
     Metrics for evaluating oriented bounding box (OBB) detection.
 
     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection plots.
         names (Dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'obb'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
 
     References:
         https://arxiv.org/pdf/2106.06072.pdf
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
-        Initialize an OBBMetrics instance with directory, plotting, and class names.
+        Initialize an OBBMetrics instance with class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+        DetMetrics.__init__(self, names)
+        # TODO: probably remove task as well
         self.task = "obb"
-
-    def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
-        """
-        Process predicted results for object detection and update metrics.
-
-        Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
-        """
-        results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            on_plot=on_plot,
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results)
-
-    @property
-    def keys(self) -> List[str]:
-        """Return a list of keys for accessing specific metrics."""
-        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
-
-    def mean_results(self) -> List[float]:
-        """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
-        return self.box.mean_results()
-
-    def class_result(self, i: int) -> Tuple[float, float, float, float]:
-        """Return the result of evaluating the performance of an object detection model on a specific class."""
-        return self.box.class_result(i)
-
-    @property
-    def maps(self) -> np.ndarray:
-        """Return mean Average Precision (mAP) scores per class."""
-        return self.box.maps
-
-    @property
-    def fitness(self) -> float:
-        """Return the fitness of box object."""
-        return self.box.fitness()
-
-    @property
-    def ap_class_index(self) -> List:
-        """Return the average precision index per class."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self) -> Dict[str, float]:
-        """Return dictionary of computed performance metrics and statistics."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
-
-    @property
-    def curves(self) -> List:
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    @property
-    def curves_results(self) -> List:
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
-        """
-        Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
-        scalar metrics (mAP, mAP50, mAP75) along with precision, recall, and F1-score for each class.
-
-        Args:
-            normalize (bool): For OBB metrics, everything is normalized by default [0-1].
-            decimals (int): Number of decimal places to round the metrics values to.
-
-        Returns:
-            (List[Dict[str, Union[str, float]]]): A list of dictionaries, each representing one class with detection metrics.
-
-        Examples:
-            >>> results = model.val(data="dota8.yaml")
-            >>> detection_summary = results.summary(decimals=4)
-            >>> print(detection_summary)
-        """
-        scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
-        }
-        per_class = {"box-p": self.box.p, "box-r": self.box.r, "box-f1": self.box.f1}
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
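After this change `OBBMetrics` is a thin subclass: construction (plus the `task` tag) is all it defines, while `process`, `keys`, `maps`, `fitness`, and `summary` are inherited from `DetMetrics`. A minimal sanity sketch:

    om = OBBMetrics(names={0: "plane", 1: "ship"})
    assert isinstance(om, DetMetrics)  # update_stats/process/clear_stats all come from the base class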