1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
|
- # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
- """
- Check a model's accuracy on a test or val split of a dataset.
- Usage:
- $ yolo mode=val model=yolo11n.pt data=coco8.yaml imgsz=640
- Usage - formats:
- $ yolo mode=val model=yolo11n.pt # PyTorch
- yolo11n.torchscript # TorchScript
- yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
- yolo11n_openvino_model # OpenVINO
- yolo11n.engine # TensorRT
- yolo11n.mlpackage # CoreML (macOS-only)
- yolo11n_saved_model # TensorFlow SavedModel
- yolo11n.pb # TensorFlow GraphDef
- yolo11n.tflite # TensorFlow Lite
- yolo11n_edgetpu.tflite # TensorFlow Edge TPU
- yolo11n_paddle_model # PaddlePaddle
- yolo11n.mnn # MNN
- yolo11n_ncnn_model # NCNN
- yolo11n_imx_model # Sony IMX
- yolo11n_rknn_model # Rockchip RKNN
- """
- import json
- import time
- from pathlib import Path
- import numpy as np
- import torch
- from ultralytics.cfg import get_cfg, get_save_dir
- from ultralytics.data.utils import check_cls_dataset, check_det_dataset
- from ultralytics.nn.autobackend import AutoBackend
- from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
- from ultralytics.utils.checks import check_imgsz
- from ultralytics.utils.ops import Profile
- from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
class BaseValidator:
    """
    A base class for creating validators.

    This class provides the foundation for validation processes, including model evaluation, metric computation, and
    result visualization. Task-specific validators are expected to override the hook methods (dataloader creation,
    pre/post-processing, metric handling, plotting and JSON export), most of which are no-ops here.

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): Dataloader to use for validation.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary containing dataset information.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether validation is being run from inside a trainer.
        names (dict): Class names mapping.
        seen (int): Number of images seen so far during validation.
        stats (dict): Statistics collected during validation.
        confusion_matrix: Confusion matrix for classification evaluation.
        nc (int): Number of classes.
        iouv (torch.Tensor): IoU thresholds read by match_predictions; initialised to None here.
        jdict (list): List to store JSON validation results.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
            per-image processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.
        stride (int): Model stride for padding calculations.
        loss (torch.Tensor): Loss accumulated over batches when validating from a trainer.

    Methods:
        __call__: Execute validation process, running inference on dataloader and computing performance metrics.
        match_predictions: Match predictions to ground truth objects using IoU.
        add_callback: Append the given callback to the specified event.
        run_callbacks: Run all callbacks associated with a specified event.
        get_dataloader: Get data loader from dataset path and batch size.
        build_dataset: Build dataset from image path.
        preprocess: Preprocess an input batch.
        postprocess: Postprocess the predictions.
        init_metrics: Initialize performance metrics for the YOLO model.
        update_metrics: Update metrics based on predictions and batch.
        finalize_metrics: Finalize and return all metrics.
        get_stats: Return statistics about the model's performance.
        print_results: Print the results of the model's predictions.
        get_desc: Get description of the YOLO model.
        on_plot: Register plots for visualization.
        plot_val_samples: Plot validation samples during training.
        plot_predictions: Plot YOLO model predictions on batch images.
        pred_to_json: Convert predictions to JSON format.
        eval_json: Evaluate and return JSON format of prediction statistics.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
        """
        Initialize a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results. Derived from the config when omitted.
            args (SimpleNamespace, optional): Configuration overrides for the validator.
            _callbacks (dict, optional): Dictionary to store various callback functions.
        """
        self.args = get_cfg(overrides=args)
        self.dataloader = dataloader
        self.stride = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True
        self.names = None
        self.seen = None
        self.stats = None
        self.confusion_matrix = None
        self.nc = None
        self.iouv = None
        self.jdict = None
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        # Create the output directory up front (including the labels sub-directory when text labels are saved)
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.01 if self.args.task == "obb" else 0.001  # reduce OBB val memory usage
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """
        Execute validation process, running inference on dataloader and computing performance metrics.

        Args:
            trainer (object, optional): Trainer object that contains the model to validate.
            model (nn.Module, optional): Model to validate if not using a trainer.

        Returns:
            (dict): Dictionary containing validation statistics. When called from a trainer the values are rounded
                floats and include the validation loss items.

        Raises:
            FileNotFoundError: If, in standalone mode, the dataset is neither a YAML file nor a classification
                dataset.
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            # Force FP16 val during training
            self.args.half = self.device.type != "cpu" and trainer.amp
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            # Only plot on the final epoch or when early stopping may trigger
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            if str(self.args.model).endswith(".yaml") and model is None:
                LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                weights=model or self.args.model,
                device=select_device(self.args.device, self.args.batch),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if engine:
                self.args.batch = model.batch_size
            elif not (pt or jit or getattr(model, "dynamic", False)):
                self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
                LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

            if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in {"cpu", "mps"}:
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
                self.args.rect = False
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        # One profiler per stage, in the same order as the keys of self.speed
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch["img"], augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")

        stats = self.get_stats()
        # Per-image time in milliseconds; clamp the divisor so an empty dataset cannot raise ZeroDivisionError
        num_images = max(len(self.dataloader.dataset), 1)
        self.speed = dict(zip(self.speed.keys(), (x.t / num_images * 1e3 for x in dt)))
        self.finalize_metrics()
        self.print_results()
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info(
                "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                    *tuple(self.speed.values())
                )
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    def match_predictions(
        self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
    ) -> torch.Tensor:
        """
        Match predictions to ground truth objects using IoU.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape (N,).
            true_classes (torch.Tensor): Target class indices of shape (M,).
            iou (torch.Tensor): Pairwise IoU values of shape (M, N): rows index ground-truth objects, columns index
                predictions.
            use_scipy (bool, optional): Whether to use scipy's linear sum assignment for one-to-one matching instead
                of the greedy matcher.

        Returns:
            (torch.Tensor): Boolean tensor of shape (N, T), where T is the number of thresholds in self.iouv, marking
                the predictions that were matched at each threshold.
        """
        # DxT matrix, where D - detections, T - IoU thresholds
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0]), dtype=bool)
        # LxD matrix where L - labels (rows), D - detections (columns)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        if use_scipy:
            # Scoped import to avoid importing scipy for all commands; import the submodule member explicitly
            # rather than relying on `scipy.optimize` being reachable from a bare `import scipy`.
            from scipy.optimize import linear_sum_assignment
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                # NOTE(review): this path was previously flagged as reducing mAP, see
                # https://github.com/ultralytics/ultralytics/pull/4708
                cost_matrix = iou * (iou >= threshold)
                if cost_matrix.any():
                    # linear_sum_assignment minimises by default; maximize=True is needed so that the assignment
                    # favours high-IoU pairs instead of the zeroed-out (invalid) ones.
                    labels_idx, detections_idx = linear_sum_assignment(cost_matrix, maximize=True)
                    valid = cost_matrix[labels_idx, detections_idx] > 0
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                matches = np.nonzero(iou >= threshold)  # IoU >= threshold and classes match
                matches = np.array(matches).T  # rows of (label index, detection index)
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        # Sort candidate pairs by descending IoU, then keep one pair per detection and per label
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

    def add_callback(self, event: str, callback):
        """Append the given callback to the specified event."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Run all callbacks associated with a specified event, passing this validator to each."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size; must be implemented by subclasses."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset from image path; must be implemented by subclasses."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocess an input batch; the base implementation returns it unchanged."""
        return batch

    def postprocess(self, preds):
        """Postprocess the predictions; the base implementation returns them unchanged."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    def update_metrics(self, preds, batch):
        """Update metrics based on predictions and batch."""
        pass

    def finalize_metrics(self):
        """Finalize and return all metrics."""
        pass

    def get_stats(self):
        """Return statistics about the model's performance; the base implementation returns an empty dict."""
        return {}

    def print_results(self):
        """Print the results of the model's predictions."""
        pass

    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    @property
    def metric_keys(self):
        """Return the metric keys used in YOLO training/validation."""
        return []

    def on_plot(self, name, data=None):
        """Register plots for visualization, keyed by path and stamped with the current time."""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    def plot_val_samples(self, batch, ni):
        """Plot validation samples during training."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plot YOLO model predictions on batch images."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    def eval_json(self, stats):
        """
        Evaluate and return JSON format of prediction statistics.

        The base implementation performs no evaluation and returns `stats` unchanged, so that `__call__` does not
        replace the collected statistics with None when a subclass does not override this hook.
        """
        return stats
|