# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, mask_iou

class SegmentationValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on a segmentation model.

    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
    to compute metrics such as mAP for both detection and segmentation tasks.

    Attributes:
        plot_masks (list): List to store masks for plotting.
        process (callable): Function to process masks based on save_json and save_txt flags.
        args (namespace): Arguments for the validator.
        metrics (SegmentMetrics): Metrics calculator for segmentation tasks.
        stats (dict): Dictionary to store statistics during validation.

    Examples:
        >>> from ultralytics.models.yolo.segment import SegmentationValidator
        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
        >>> validator = SegmentationValidator(args=args)
        >>> validator()
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """
        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
            save_dir (Path, optional): Directory to save results.
            args (namespace, optional): Arguments for the validator.
            _callbacks (list, optional): List of callback functions.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.process = None
        self.args.task = "segment"
        self.metrics = SegmentMetrics()

    def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess batch of images for YOLO segmentation validation.

        Args:
            batch (Dict[str, Any]): Batch containing images and annotations.

        Returns:
            (Dict[str, Any]): Preprocessed batch.
        """
        batch = super().preprocess(batch)
        batch["masks"] = batch["masks"].to(self.device).float()
        return batch

    def init_metrics(self, model: torch.nn.Module) -> None:
        """
        Initialize metrics and select mask processing function based on save_json flag.

        Args:
            model (torch.nn.Module): Model to validate.
        """
        super().init_metrics(model)
        if self.args.save_json:
            check_requirements("pycocotools>=2.0.6")
        # More accurate vs faster
        self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
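
    # Note on the choice above: ops.process_mask_native upsamples the prototype masks to the input
    # image size before cropping with the boxes (more accurate, slower), while ops.process_mask crops
    # at prototype resolution (faster); the native path is used whenever masks are exported for
    # save_json or save_txt.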

    def get_desc(self) -> str:
        """Return a formatted description of evaluation metrics."""
        return ("%22s" + "%11s" * 10) % (
            "Class",
            "Images",
            "Instances",
            "Box(P",
            "R",
            "mAP50",
            "mAP50-95)",
            "Mask(P",
            "R",
            "mAP50",
            "mAP50-95)",
        )

    def postprocess(self, preds: List[torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
        """
        Post-process YOLO predictions and return output detections with proto.

        Args:
            preds (List[torch.Tensor]): Raw predictions from the model.

        Returns:
            (List[Dict[str, torch.Tensor]]): Processed detection predictions with masks.
        """
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        preds = super().postprocess(preds[0])
        imgsz = [4 * x for x in proto.shape[2:]]  # get image size from proto
        for i, pred in enumerate(preds):
            coefficient = pred.pop("extra")
            pred["masks"] = (
                self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
                if len(coefficient)
                else torch.zeros((0, imgsz[0], imgsz[1]), dtype=torch.uint8, device=pred["bboxes"].device)
            )
        return preds
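
    # Illustrative sketch of the per-image dict returned above (shapes assume N kept detections;
    # the mask height/width depends on which `self.process` function is active):
    #   pred["bboxes"] -> (N, 4) xyxy boxes, pred["conf"] -> (N,), pred["cls"] -> (N,)
    #   pred["masks"]  -> (N, H, W) instance masks, at prototype resolution for ops.process_mask
    #                     or at full input size for ops.process_mask_native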

    def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Prepare a batch of images and annotations for validation.

        Args:
            si (int): Index of the image within the batch.
            batch (Dict[str, Any]): Batch data containing images and annotations.

        Returns:
            (Dict[str, Any]): Prepared batch with processed annotations.
        """
        prepared_batch = super()._prepare_batch(si, batch)
        midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
        prepared_batch["masks"] = batch["masks"][midx]
        return prepared_batch
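
    # Note on overlap_mask: when enabled, the dataloader stores all instances of one image in a
    # single (1, H, W) index map whose pixel values run from 1..n_instances, so `midx = [si]`
    # selects that one map and `_process_batch` later expands it into per-instance binary masks;
    # when disabled, masks are stacked per instance and selected with the `batch_idx == si` mask.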

    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """
        Prepare predictions for evaluation by processing bounding boxes and masks.

        Args:
            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
            pbatch (Dict[str, Any]): Prepared batch information.

        Returns:
            (Dict[str, torch.Tensor]): Processed predictions with scaled bounding boxes and masks.
        """
        predn = super()._prepare_pred(pred, pbatch)
        predn["masks"] = pred["masks"]
        if self.args.save_json and len(predn["masks"]):
            coco_masks = torch.as_tensor(pred["masks"], dtype=torch.uint8)
            coco_masks = ops.scale_image(
                coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
                pbatch["ori_shape"],
                ratio_pad=pbatch["ratio_pad"],
            )
            predn["coco_masks"] = coco_masks
        return predn

    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """
        Compute correct prediction matrix for a batch based on bounding boxes and optional masks.

        Args:
            preds (Dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
            batch (Dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.

        Returns:
            (Dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.

        Notes:
            - Mask IoU ('tp_m') is computed in addition to the box IoU matrix returned by the parent class.
            - If `overlap_mask` is enabled, the single ground-truth index map is expanded into per-instance
              binary masks before IoU is computed.

        Examples:
            >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
            >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
            >>> correct_preds = validator._process_batch(preds, batch)
        """
        tp = super()._process_batch(preds, batch)
        gt_cls, gt_masks = batch["cls"], batch["masks"]
        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
            tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
        else:
            pred_masks = preds["masks"]
            if self.args.overlap_mask:
                nl = len(gt_cls)
                index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
                gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
                gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
            if gt_masks.shape[1:] != pred_masks.shape[1:]:
                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
                gt_masks = gt_masks.gt_(0.5)
            iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
            tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
        tp.update({"tp_m": tp_m})  # update tp with mask IoU
        return tp
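
    # Minimal standalone sketch of the mask-IoU step above (hypothetical values, not used by the
    # validator): mask_iou flattens each (H, W) mask into a row vector and returns the pairwise
    # intersection-over-union matrix between ground-truth and predicted masks.
    #   >>> import torch
    #   >>> from ultralytics.utils.metrics import mask_iou
    #   >>> gt = torch.zeros(2, 160, 160); gt[0, :80] = 1; gt[1, 80:] = 1
    #   >>> pr = torch.zeros(3, 160, 160); pr[0, :80] = 1; pr[1, 40:120] = 1; pr[2, 80:] = 1
    #   >>> iou = mask_iou(gt.view(2, -1), pr.view(3, -1))  # shape (2, 3), values in [0, 1]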

    def plot_predictions(self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int) -> None:
        """
        Plot batch predictions with masks and bounding boxes.

        Args:
            batch (Dict[str, Any]): Batch containing images and annotations.
            preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
            ni (int): Batch index.
        """
        for p in preds:
            masks = p["masks"]
            if masks.shape[0] > 50:
                LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
            p["masks"] = torch.as_tensor(masks[:50], dtype=torch.uint8).cpu()
        super().plot_predictions(batch, preds, ni, max_det=50)  # plot bboxes

    def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
        """
        Save YOLO segmentation results to a txt file in normalized coordinates.

        Args:
            predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
            save_conf (bool): Whether to save confidence scores.
            shape (Tuple[int, int]): Shape of the original image.
            file (Path): File path to save the detections.
        """
        from ultralytics.engine.results import Results

        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
            masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
        ).save_txt(file, save_conf=save_conf)

    def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
        """
        Save one JSON result for COCO evaluation.

        Args:
            predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
            filename (str): Image filename.

        Examples:
            >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
        """
        from pycocotools.mask import encode  # noqa

        def single_encode(x):
            """Encode one predicted mask as run-length encoding (RLE)."""
            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
            rle["counts"] = rle["counts"].decode("utf-8")
            return rle

        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        pred_masks = np.transpose(predn["coco_masks"], (2, 0, 1))
        with ThreadPool(NUM_THREADS) as pool:
            rles = pool.map(single_encode, pred_masks)
        for i, (b, s, c) in enumerate(zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist())):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "category_id": self.class_map[int(c)],
                    "bbox": [round(x, 3) for x in b],
                    "score": round(s, 5),
                    "segmentation": rles[i],
                }
            )
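
    # Illustrative shape of one appended record (values made up, reusing the docstring example):
    # the "segmentation" field is the RLE dict produced by pycocotools.mask.encode with its byte
    # "counts" decoded to a UTF-8 string, e.g.
    #   {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78],
    #    "score": 0.236, "segmentation": {"size": [480, 640], "counts": "..."}}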

    def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
        """Return COCO-style instance segmentation evaluation metrics."""
        if self.args.save_json and (self.is_lvis or self.is_coco) and len(self.jdict):
            pred_json = self.save_dir / "predictions.json"  # predictions
            anno_json = (
                self.data["path"]
                / "annotations"
                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
            )  # annotations
            pkg = "pycocotools" if self.is_coco else "lvis"
            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                for x in anno_json, pred_json:
                    assert x.is_file(), f"{x} file not found"
                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
                if self.is_coco:
                    from pycocotools.coco import COCO  # noqa
                    from pycocotools.cocoeval import COCOeval  # noqa

                    anno = COCO(str(anno_json))  # init annotations api
                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                    vals = [COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]
                else:
                    from lvis import LVIS, LVISEval

                    anno = LVIS(str(anno_json))
                    pred = anno._load_json(str(pred_json))
                    vals = [LVISEval(anno, pred, "bbox"), LVISEval(anno, pred, "segm")]
                for i, eval in enumerate(vals):
                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
                    eval.evaluate()
                    eval.accumulate()
                    eval.summarize()
                    if self.is_lvis:
                        eval.print_results()
                    idx = i * 4 + 2
                    # update mAP50-95 and mAP50
                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = (
                        eval.stats[:2] if self.is_coco else [eval.results["AP"], eval.results["AP50"]]
                    )
                    if self.is_lvis:
                        tag = "B" if i == 0 else "M"
                        stats[f"metrics/APr({tag})"] = eval.results["APr"]
                        stats[f"metrics/APc({tag})"] = eval.results["APc"]
                        stats[f"metrics/APf({tag})"] = eval.results["APf"]
                if self.is_lvis:
                    stats["fitness"] = stats["metrics/mAP50-95(B)"]
            except Exception as e:
                LOGGER.warning(f"{pkg} unable to run: {e}")
        return stats
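

# Hedged usage sketch (assumes a local segmentation checkpoint such as "yolo11n-seg.pt" and a
# dataset YAML such as "coco8-seg.yaml", mirroring the class docstring example):
#   >>> from ultralytics.models.yolo.segment import SegmentationValidator
#   >>> validator = SegmentationValidator(args=dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", save_json=True))
#   >>> stats = validator()  # runs validation; with save_json=True a predictions.json is written and
#   >>> #                      eval_json() merges pycocotools/lvis mAP values into the returned stats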