validator.py

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Check a model's accuracy on a test or val split of a dataset.

Usage:
    $ yolo mode=val model=yolo11n.pt data=coco8.yaml imgsz=640

Usage - formats:
    $ yolo mode=val model=yolo11n.pt                 # PyTorch
                          yolo11n.torchscript        # TorchScript
                          yolo11n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                          yolo11n_openvino_model     # OpenVINO
                          yolo11n.engine             # TensorRT
                          yolo11n.mlpackage          # CoreML (macOS-only)
                          yolo11n_saved_model        # TensorFlow SavedModel
                          yolo11n.pb                 # TensorFlow GraphDef
                          yolo11n.tflite             # TensorFlow Lite
                          yolo11n_edgetpu.tflite     # TensorFlow Edge TPU
                          yolo11n_paddle_model       # PaddlePaddle
                          yolo11n.mnn                # MNN
                          yolo11n_ncnn_model         # NCNN
                          yolo11n_imx_model          # Sony IMX
                          yolo11n_rknn_model         # Rockchip RKNN
"""
import json
import time
from pathlib import Path

import numpy as np
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode


class BaseValidator:
    """
    A base class for creating validators.

    This class provides the foundation for validation processes, including model evaluation, metric computation, and
    result visualization.

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): Dataloader to use for validation.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary containing dataset information.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether the model is in training mode.
        names (dict): Class names mapping.
        seen (int): Number of images seen so far during validation.
        stats (dict): Statistics collected during validation.
        confusion_matrix: Confusion matrix for classification evaluation.
        nc (int): Number of classes.
        iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
        jdict (list): List to store JSON validation results.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
            batch processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.
        stride (int): Model stride for padding calculations.
        loss (torch.Tensor): Accumulated loss during training validation.

    Methods:
        __call__: Execute validation process, running inference on dataloader and computing performance metrics.
        match_predictions: Match predictions to ground truth objects using IoU.
        add_callback: Append the given callback to the specified event.
        run_callbacks: Run all callbacks associated with a specified event.
        get_dataloader: Get data loader from dataset path and batch size.
        build_dataset: Build dataset from image path.
        preprocess: Preprocess an input batch.
        postprocess: Postprocess the predictions.
        init_metrics: Initialize performance metrics for the YOLO model.
        update_metrics: Update metrics based on predictions and batch.
        finalize_metrics: Finalize and return all metrics.
        get_stats: Return statistics about the model's performance.
        print_results: Print the results of the model's predictions.
        get_desc: Get description of the YOLO model.
        on_plot: Register plots for visualization.
        plot_val_samples: Plot validation samples during training.
        plot_predictions: Plot YOLO model predictions on batch images.
        pred_to_json: Convert predictions to JSON format.
        eval_json: Evaluate and return JSON format of prediction statistics.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
        """
        Initialize a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            args (SimpleNamespace, optional): Configuration for the validator.
            _callbacks (dict, optional): Dictionary to store various callback functions.
        """
        self.args = get_cfg(overrides=args)
        self.dataloader = dataloader
        self.stride = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True
        self.names = None
        self.seen = None
        self.stats = None
        self.confusion_matrix = None
        self.nc = None
        self.iouv = None
        self.jdict = None
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.01 if self.args.task == "obb" else 0.001  # reduce OBB val memory usage
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """
        Execute validation process, running inference on dataloader and computing performance metrics.

        Args:
            trainer (object, optional): Trainer object that contains the model to validate.
            model (nn.Module, optional): Model to validate if not using a trainer.

        Returns:
            (dict): Dictionary containing validation statistics.
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            # Force FP16 val during training
            self.args.half = self.device.type != "cpu" and trainer.amp
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            if str(self.args.model).endswith(".yaml") and model is None:
                LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                weights=model or self.args.model,
                device=select_device(self.args.device, self.args.batch),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if engine:
                self.args.batch = model.batch_size
            elif not (pt or jit or getattr(model, "dynamic", False)):
                self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
                LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

            if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in {"cpu", "mps"}:
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
                self.args.rect = False
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch["img"], augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")
        stats = self.get_stats()
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
        self.print_results()
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info(
                "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                    *tuple(self.speed.values())
                )
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    def match_predictions(
        self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
    ) -> torch.Tensor:
        """
        Match predictions to ground truth objects using IoU.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape (N,).
            true_classes (torch.Tensor): Target class indices of shape (M,).
            iou (torch.Tensor): An MxN tensor containing the pairwise IoU values for ground truth (rows) and
                predictions (columns).
            use_scipy (bool, optional): Whether to use scipy for matching (more precise).

        Returns:
            (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
        """
        # Dx10 matrix, where D - detections, 10 - IoU thresholds
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
        # LxD matrix where L - labels (rows), D - detections (columns)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
                import scipy  # scope import to avoid importing for all commands

                cost_matrix = iou * (iou >= threshold)
                if cost_matrix.any():
                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix)
                    valid = cost_matrix[labels_idx, detections_idx] > 0
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                matches = np.nonzero(iou >= threshold)  # IoU >= threshold and classes match
                matches = np.array(matches).T
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
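
    # Toy illustration (hypothetical numbers) of the greedy branch above: with one label and
    # two detections of the matching class, iou = [[0.80, 0.60]] and threshold = 0.50, both
    # pairs (0, 0) and (0, 1) pass the threshold; sorting by IoU and de-duplicating first per
    # detection and then per label keeps only (label 0, detection 0), so correct[0, i] is set
    # True while detection 1 remains False at this threshold.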

    def add_callback(self, event: str, callback):
        """Append the given callback to the specified event."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Run all callbacks associated with a specified event."""
        for callback in self.callbacks.get(event, []):
            callback(self)
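
    # Example sketch (hypothetical callback name): registering a callback on a validator
    # instance so that it fires at the "on_val_batch_end" event emitted in __call__ above:
    #   def log_batch(validator):
    #       LOGGER.info(f"Finished val batch {validator.batch_i}")
    #   validator.add_callback("on_val_batch_end", log_batch)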

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset from image path."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocess an input batch."""
        return batch

    def postprocess(self, preds):
        """Postprocess the predictions."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    def update_metrics(self, preds, batch):
        """Update metrics based on predictions and batch."""
        pass

    def finalize_metrics(self):
        """Finalize and return all metrics."""
        pass

    def get_stats(self):
        """Return statistics about the model's performance."""
        return {}

    def print_results(self):
        """Print the results of the model's predictions."""
        pass

    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    @property
    def metric_keys(self):
        """Return the metric keys used in YOLO training/validation."""
        return []

    def on_plot(self, name, data=None):
        """Register plots for visualization."""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    def plot_val_samples(self, batch, ni):
        """Plot validation samples during training."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plot YOLO model predictions on batch images."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    def eval_json(self, stats):
        """Evaluate and return JSON format of prediction statistics."""
        pass
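

# --------------------------------------------------------------------------------------
# Illustrative sketch (not part of the Ultralytics API): a minimal subclass showing which
# hooks a task-specific validator typically overrides. `MyValidator` and
# `build_my_dataloader` are hypothetical placeholder names.
#
#   class MyValidator(BaseValidator):
#       def get_dataloader(self, dataset_path, batch_size):
#           return build_my_dataloader(dataset_path, batch_size)  # hypothetical helper
#
#       def preprocess(self, batch):
#           batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
#           return batch
#
#       def init_metrics(self, model):
#           self.names = model.names
#           self.seen = 0
#
#       def update_metrics(self, preds, batch):
#           self.seen += batch["img"].shape[0]
#
#       def get_stats(self):
#           return {"images_seen": self.seen}
#
# Calling an instance, e.g. MyValidator(args=dict(model="yolo11n.pt", data="coco8.yaml"))(),
# would then drive the __call__ loop above with these overrides.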