Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commits into Ultralytics:main from ultralytics:yoloe-vp-fix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
  1. # Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license
  2. import gc
  3. import math
  4. import os
  5. import random
  6. import time
  7. from contextlib import contextmanager
  8. from copy import deepcopy
  9. from datetime import datetime
  10. from pathlib import Path
  11. from typing import Union
  12. import numpy as np
  13. import torch
  14. import torch.distributed as dist
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from ultralytics import __version__
  18. from ultralytics.utils import (
  19. DEFAULT_CFG_DICT,
  20. DEFAULT_CFG_KEYS,
  21. LOGGER,
  22. NUM_THREADS,
  23. PYTHON_VERSION,
  24. TORCHVISION_VERSION,
  25. WINDOWS,
  26. colorstr,
  27. )
  28. from ultralytics.utils.checks import check_version
  29. try:
  30. import thop
  31. except ImportError:
  32. thop = None # conda support without 'ultralytics-thop' installed
  33. # Version checks (all default to version>=min_version)
  34. TORCH_1_9 = check_version(torch.__version__, "1.9.0")
  35. TORCH_1_13 = check_version(torch.__version__, "1.13.0")
  36. TORCH_2_0 = check_version(torch.__version__, "2.0.0")
  37. TORCH_2_4 = check_version(torch.__version__, "2.4.0")
  38. TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
  39. TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
  40. TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
  41. TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
  42. if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows
  43. LOGGER.warning(
  44. "Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
  45. "https://github.com/ultralytics/ultralytics/issues/15049"
  46. )
  47. @contextmanager
  48. def torch_distributed_zero_first(local_rank: int):
  49. """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
  50. initialized = dist.is_available() and dist.is_initialized()
  51. use_ids = initialized and dist.get_backend() == "nccl"
  52. if initialized and local_rank not in {-1, 0}:
  53. dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()
  54. yield
  55. if initialized and local_rank == 0:
  56. dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()
  57. def smart_inference_mode():
  58. """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
  59. def decorate(fn):
  60. """Applies appropriate torch decorator for inference mode based on torch version."""
  61. if TORCH_1_9 and torch.is_inference_mode_enabled():
  62. return fn # already in inference_mode, act as a pass-through
  63. else:
  64. return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
  65. return decorate
  66. def autocast(enabled: bool, device: str = "cuda"):
  67. """
  68. Get the appropriate autocast context manager based on PyTorch version and AMP setting.
  69. This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
  70. older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
  71. Args:
  72. enabled (bool): Whether to enable automatic mixed precision.
  73. device (str, optional): The device to use for autocast. Defaults to 'cuda'.
  74. Returns:
  75. (torch.amp.autocast): The appropriate autocast context manager.
  76. Notes:
  77. - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
  78. - For older versions, it uses `torch.cuda.autocast`.
  79. Examples:
  80. >>> with autocast(enabled=True):
  81. ... # Your mixed precision operations here
  82. ... pass
  83. """
  84. if TORCH_1_13:
  85. return torch.amp.autocast(device, enabled=enabled)
  86. else:
  87. return torch.cuda.amp.autocast(enabled)
  88. def get_cpu_info():
  89. """Return a string with system CPU information, i.e. 'Apple M2'."""
  90. from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
  91. if "cpu_info" not in PERSISTENT_CACHE:
  92. try:
  93. import cpuinfo # pip install py-cpuinfo
  94. k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference
  95. info = cpuinfo.get_cpu_info() # info dict
  96. string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
  97. PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
  98. except Exception:
  99. pass
  100. return PERSISTENT_CACHE.get("cpu_info", "unknown")
  101. def get_gpu_info(index):
  102. """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
  103. properties = torch.cuda.get_device_properties(index)
  104. return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
  105. def select_device(device="", batch=0, newline=False, verbose=True):
  106. """
  107. Select the appropriate PyTorch device based on the provided arguments.
  108. The function takes a string specifying the device or a torch.device object and returns a torch.device object
  109. representing the selected device. The function also validates the number of available devices and raises an
  110. exception if the requested device(s) are not available.
  111. Args:
  112. device (str | torch.device, optional): Device string or torch.device object.
  113. Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
  114. the first available GPU, or CPU if no GPU is available.
  115. batch (int, optional): Batch size being used in your model. Defaults to 0.
  116. newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
  117. verbose (bool, optional): If True, logs the device information. Defaults to True.
  118. Returns:
  119. (torch.device): Selected device.
  120. Raises:
  121. ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
  122. devices when using multiple GPUs.
  123. Examples:
  124. >>> select_device("cuda:0")
  125. device(type='cuda', index=0)
  126. >>> select_device("cpu")
  127. device(type='cpu')
  128. Note:
  129. Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
  130. """
  131. if isinstance(device, torch.device) or str(device).startswith("tpu") or str(device).startswith("intel"):
  132. return device
  133. s = f"Ultralytics {__version__} ๐Ÿš€ Python-{PYTHON_VERSION} torch-{torch.__version__} "
  134. device = str(device).lower()
  135. for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
  136. device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
  137. cpu = device == "cpu"
  138. mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS)
  139. if cpu or mps:
  140. os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
  141. elif device: # non-cpu device requested
  142. if device == "cuda":
  143. device = "0"
  144. if "," in device:
  145. device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. "0,,1" -> "0,1"
  146. visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  147. os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
  148. if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
  149. LOGGER.info(s)
  150. install = (
  151. "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
  152. "CUDA devices are seen by torch.\n"
  153. if torch.cuda.device_count() == 0
  154. else ""
  155. )
  156. raise ValueError(
  157. f"Invalid CUDA 'device={device}' requested."
  158. f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
  159. f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
  160. f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
  161. f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
  162. f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
  163. f"{install}"
  164. )
  165. if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
  166. devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"]
  167. n = len(devices) # device count
  168. if n > 1: # multi-GPU
  169. if batch < 1:
  170. raise ValueError(
  171. "AutoBatch with batch<1 not supported for Multi-GPU training, "
  172. "please specify a valid batch size, i.e. batch=16."
  173. )
  174. if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count
  175. raise ValueError(
  176. f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
  177. f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
  178. )
  179. space = " " * (len(s) + 1)
  180. for i, d in enumerate(devices):
  181. s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB
  182. arg = "cuda:0"
  183. elif mps and TORCH_2_0 and torch.backends.mps.is_available():
  184. # Prefer MPS if available
  185. s += f"MPS ({get_cpu_info()})\n"
  186. arg = "mps"
  187. else: # revert to CPU
  188. s += f"CPU ({get_cpu_info()})\n"
  189. arg = "cpu"
  190. if arg in {"cpu", "mps"}:
  191. torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training
  192. if verbose:
  193. LOGGER.info(s if newline else s.rstrip())
  194. return torch.device(arg)
  195. def time_sync():
  196. """PyTorch-accurate time."""
  197. if torch.cuda.is_available():
  198. torch.cuda.synchronize()
  199. return time.time()
  200. def fuse_conv_and_bn(conv, bn):
  201. """Fuse Conv2d() and BatchNorm2d() layers."""
  202. fusedconv = (
  203. nn.Conv2d(
  204. conv.in_channels,
  205. conv.out_channels,
  206. kernel_size=conv.kernel_size,
  207. stride=conv.stride,
  208. padding=conv.padding,
  209. dilation=conv.dilation,
  210. groups=conv.groups,
  211. bias=True,
  212. )
  213. .requires_grad_(False)
  214. .to(conv.weight.device)
  215. )
  216. # Prepare filters
  217. w_conv = conv.weight.view(conv.out_channels, -1)
  218. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  219. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  220. # Prepare spatial bias
  221. b_conv = (
  222. torch.zeros(conv.weight.shape[0], dtype=conv.weight.dtype, device=conv.weight.device)
  223. if conv.bias is None
  224. else conv.bias
  225. )
  226. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  227. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  228. return fusedconv
  229. def fuse_deconv_and_bn(deconv, bn):
  230. """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
  231. fuseddconv = (
  232. nn.ConvTranspose2d(
  233. deconv.in_channels,
  234. deconv.out_channels,
  235. kernel_size=deconv.kernel_size,
  236. stride=deconv.stride,
  237. padding=deconv.padding,
  238. output_padding=deconv.output_padding,
  239. dilation=deconv.dilation,
  240. groups=deconv.groups,
  241. bias=True,
  242. )
  243. .requires_grad_(False)
  244. .to(deconv.weight.device)
  245. )
  246. # Prepare filters
  247. w_deconv = deconv.weight.view(deconv.out_channels, -1)
  248. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  249. fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
  250. # Prepare spatial bias
  251. b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
  252. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  253. fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  254. return fuseddconv
  255. def model_info(model, detailed=False, verbose=True, imgsz=640):
  256. """
  257. Print and return detailed model information layer by layer.
  258. Args:
  259. model (nn.Module): Model to analyze.
  260. detailed (bool, optional): Whether to print detailed layer information. Defaults to False.
  261. verbose (bool, optional): Whether to print model information. Defaults to True.
  262. imgsz (int | List, optional): Input image size. Defaults to 640.
  263. Returns:
  264. (Tuple[int, int, int, float]): Number of layers, parameters, gradients, and GFLOPs.
  265. """
  266. if not verbose:
  267. return
  268. n_p = get_num_params(model) # number of parameters
  269. n_g = get_num_gradients(model) # number of gradients
  270. layers = __import__("collections").OrderedDict((n, m) for n, m in model.named_modules() if len(m._modules) == 0)
  271. n_l = len(layers) # number of layers
  272. if detailed:
  273. h = f"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}"
  274. LOGGER.info(h)
  275. for i, (mn, m) in enumerate(layers.items()):
  276. mn = mn.replace("module_list.", "")
  277. mt = m.__class__.__name__
  278. if len(m._parameters):
  279. for pn, p in m.named_parameters():
  280. LOGGER.info(
  281. f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
  282. )
  283. else: # layers with no learnable params
  284. LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")
  285. flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
  286. fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
  287. fs = f", {flops:.1f} GFLOPs" if flops else ""
  288. yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
  289. model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
  290. LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
  291. return n_l, n_p, n_g, flops
  292. def get_num_params(model):
  293. """Return the total number of parameters in a YOLO model."""
  294. return sum(x.numel() for x in model.parameters())
  295. def get_num_gradients(model):
  296. """Return the total number of parameters with gradients in a YOLO model."""
  297. return sum(x.numel() for x in model.parameters() if x.requires_grad)
  298. def model_info_for_loggers(trainer):
  299. """
  300. Return model info dict with useful model information.
  301. Args:
  302. trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.
  303. Returns:
  304. (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.
  305. Examples:
  306. YOLOv8n info for loggers
  307. >>> results = {
  308. ... "model/parameters": 3151904,
  309. ... "model/GFLOPs": 8.746,
  310. ... "model/speed_ONNX(ms)": 41.244,
  311. ... "model/speed_TensorRT(ms)": 3.211,
  312. ... "model/speed_PyTorch(ms)": 18.755,
  313. ...}
  314. """
  315. if trainer.args.profile: # profile ONNX and TensorRT times
  316. from ultralytics.utils.benchmarks import ProfileModels
  317. results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
  318. results.pop("model/name")
  319. else: # only return PyTorch times from most recent validation
  320. results = {
  321. "model/parameters": get_num_params(trainer.model),
  322. "model/GFLOPs": round(get_flops(trainer.model), 3),
  323. }
  324. results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
  325. return results
  326. def get_flops(model, imgsz=640):
  327. """
  328. Calculate FLOPs (floating point operations) for a model in billions.
  329. Attempts two calculation methods: first with a stride-based tensor for efficiency,
  330. then falls back to full image size if needed (e.g., for RTDETR models). Returns 0.0
  331. if thop library is unavailable or calculation fails.
  332. Args:
  333. model (nn.Module): The model to calculate FLOPs for.
  334. imgsz (int | List[int], optional): Input image size. Defaults to 640.
  335. Returns:
  336. (float): The model FLOPs in billions.
  337. """
  338. if not thop:
  339. return 0.0 # if not installed return 0.0 GFLOPs
  340. try:
  341. model = de_parallel(model)
  342. p = next(model.parameters())
  343. if not isinstance(imgsz, list):
  344. imgsz = [imgsz, imgsz] # expand if int/float
  345. try:
  346. # Method 1: Use stride-based input tensor
  347. stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride
  348. im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
  349. flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs
  350. return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs
  351. except Exception:
  352. # Method 2: Use actual image size (required for RTDETR models)
  353. im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
  354. return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs
  355. except Exception:
  356. return 0.0
  357. def get_flops_with_torch_profiler(model, imgsz=640):
  358. """
  359. Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).
  360. Args:
  361. model (nn.Module): The model to calculate FLOPs for.
  362. imgsz (int | List[int], optional): Input image size. Defaults to 640.
  363. Returns:
  364. (float): The model's FLOPs in billions.
  365. """
  366. if not TORCH_2_0: # torch profiler implemented in torch>=2.0
  367. return 0.0
  368. model = de_parallel(model)
  369. p = next(model.parameters())
  370. if not isinstance(imgsz, list):
  371. imgsz = [imgsz, imgsz] # expand if int/float
  372. try:
  373. # Use stride size for input tensor
  374. stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
  375. im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
  376. with torch.profiler.profile(with_flops=True) as prof:
  377. model(im)
  378. flops = sum(x.flops for x in prof.key_averages()) / 1e9
  379. flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
  380. except Exception:
  381. # Use actual image size for input tensor (i.e. required for RTDETR models)
  382. im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
  383. with torch.profiler.profile(with_flops=True) as prof:
  384. model(im)
  385. flops = sum(x.flops for x in prof.key_averages()) / 1e9
  386. return flops
  387. def initialize_weights(model):
  388. """Initialize model weights to random values."""
  389. for m in model.modules():
  390. t = type(m)
  391. if t is nn.Conv2d:
  392. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  393. elif t is nn.BatchNorm2d:
  394. m.eps = 1e-3
  395. m.momentum = 0.03
  396. elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
  397. m.inplace = True
  398. def scale_img(img, ratio=1.0, same_shape=False, gs=32):
  399. """
  400. Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple.
  401. Args:
  402. img (torch.Tensor): Input image tensor.
  403. ratio (float, optional): Scaling ratio. Defaults to 1.0.
  404. same_shape (bool, optional): Whether to maintain the same shape. Defaults to False.
  405. gs (int, optional): Grid size for padding. Defaults to 32.
  406. Returns:
  407. (torch.Tensor): Scaled and padded image tensor.
  408. """
  409. if ratio == 1.0:
  410. return img
  411. h, w = img.shape[2:]
  412. s = (int(h * ratio), int(w * ratio)) # new size
  413. img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize
  414. if not same_shape: # pad/crop img
  415. h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
  416. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  417. def copy_attr(a, b, include=(), exclude=()):
  418. """
  419. Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.
  420. Args:
  421. a (object): Destination object to copy attributes to.
  422. b (object): Source object to copy attributes from.
  423. include (tuple, optional): Attributes to include. If empty, all attributes are included. Defaults to ().
  424. exclude (tuple, optional): Attributes to exclude. Defaults to ().
  425. """
  426. for k, v in b.__dict__.items():
  427. if (len(include) and k not in include) or k.startswith("_") or k in exclude:
  428. continue
  429. else:
  430. setattr(a, k, v)
  431. def get_latest_opset():
  432. """
  433. Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.
  434. Returns:
  435. (int): The ONNX opset version.
  436. """
  437. if TORCH_1_13:
  438. # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
  439. return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
  440. # Otherwise for PyTorch<=1.12 return the corresponding predefined opset
  441. version = torch.onnx.producer_version.rsplit(".", 1)[0] # i.e. '2.3'
  442. return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
  443. def intersect_dicts(da, db, exclude=()):
  444. """
  445. Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.
  446. Args:
  447. da (dict): First dictionary.
  448. db (dict): Second dictionary.
  449. exclude (tuple, optional): Keys to exclude. Defaults to ().
  450. Returns:
  451. (dict): Dictionary of intersecting keys with matching shapes.
  452. """
  453. return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
  454. def is_parallel(model):
  455. """
  456. Returns True if model is of type DP or DDP.
  457. Args:
  458. model (nn.Module): Model to check.
  459. Returns:
  460. (bool): True if model is DataParallel or DistributedDataParallel.
  461. """
  462. return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
  463. def de_parallel(model):
  464. """
  465. De-parallelize a model: returns single-GPU model if model is of type DP or DDP.
  466. Args:
  467. model (nn.Module): Model to de-parallelize.
  468. Returns:
  469. (nn.Module): De-parallelized model.
  470. """
  471. return model.module if is_parallel(model) else model
  472. def one_cycle(y1=0.0, y2=1.0, steps=100):
  473. """
  474. Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.
  475. Args:
  476. y1 (float, optional): Initial value. Defaults to 0.0.
  477. y2 (float, optional): Final value. Defaults to 1.0.
  478. steps (int, optional): Number of steps. Defaults to 100.
  479. Returns:
  480. (function): Lambda function for computing the sinusoidal ramp.
  481. """
  482. return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
  483. def init_seeds(seed=0, deterministic=False):
  484. """
  485. Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.
  486. Args:
  487. seed (int, optional): Random seed. Defaults to 0.
  488. deterministic (bool, optional): Whether to set deterministic algorithms. Defaults to False.
  489. """
  490. random.seed(seed)
  491. np.random.seed(seed)
  492. torch.manual_seed(seed)
  493. torch.cuda.manual_seed(seed)
  494. torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
  495. # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
  496. if deterministic:
  497. if TORCH_2_0:
  498. torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible
  499. torch.backends.cudnn.deterministic = True
  500. os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
  501. os.environ["PYTHONHASHSEED"] = str(seed)
  502. else:
  503. LOGGER.warning("Upgrade to torch>=2.0.0 for deterministic training.")
  504. else:
  505. unset_deterministic()
  506. def unset_deterministic():
  507. """Unsets all the configurations applied for deterministic training."""
  508. torch.use_deterministic_algorithms(False)
  509. torch.backends.cudnn.deterministic = False
  510. os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
  511. os.environ.pop("PYTHONHASHSEED", None)
  512. class ModelEMA:
  513. """
  514. Updated Exponential Moving Average (EMA) implementation.
  515. Keeps a moving average of everything in the model state_dict (parameters and buffers).
  516. For EMA details see References.
  517. To disable EMA set the `enabled` attribute to `False`.
  518. Attributes:
  519. ema (nn.Module): Copy of the model in evaluation mode.
  520. updates (int): Number of EMA updates.
  521. decay (function): Decay function that determines the EMA weight.
  522. enabled (bool): Whether EMA is enabled.
  523. References:
  524. - https://github.com/rwightman/pytorch-image-models
  525. - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  526. """
  527. def __init__(self, model, decay=0.9999, tau=2000, updates=0):
  528. """
  529. Initialize EMA for 'model' with given arguments.
  530. Args:
  531. model (nn.Module): Model to create EMA for.
  532. decay (float, optional): Maximum EMA decay rate. Defaults to 0.9999.
  533. tau (int, optional): EMA decay time constant. Defaults to 2000.
  534. updates (int, optional): Initial number of updates. Defaults to 0.
  535. """
  536. self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
  537. self.updates = updates # number of EMA updates
  538. self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
  539. for p in self.ema.parameters():
  540. p.requires_grad_(False)
  541. self.enabled = True
  542. def update(self, model):
  543. """
  544. Update EMA parameters.
  545. Args:
  546. model (nn.Module): Model to update EMA from.
  547. """
  548. if self.enabled:
  549. self.updates += 1
  550. d = self.decay(self.updates)
  551. msd = de_parallel(model).state_dict() # model state_dict
  552. for k, v in self.ema.state_dict().items():
  553. if v.dtype.is_floating_point: # true for FP16 and FP32
  554. v *= d
  555. v += (1 - d) * msd[k].detach()
  556. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
  557. def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
  558. """
  559. Updates attributes and saves stripped model with optimizer removed.
  560. Args:
  561. model (nn.Module): Model to update attributes from.
  562. include (tuple, optional): Attributes to include. Defaults to ().
  563. exclude (tuple, optional): Attributes to exclude. Defaults to ("process_group", "reducer").
  564. """
  565. if self.enabled:
  566. copy_attr(self.ema, model, include, exclude)
  567. def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
  568. """
  569. Strip optimizer from 'f' to finalize training, optionally save as 's'.
  570. Args:
  571. f (str | Path): File path to model to strip the optimizer from. Defaults to 'best.pt'.
  572. s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
  573. updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.
  574. Returns:
  575. (dict): The combined checkpoint dictionary.
  576. Examples:
  577. >>> from pathlib import Path
  578. >>> from ultralytics.utils.torch_utils import strip_optimizer
  579. >>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
  580. >>> strip_optimizer(f)
  581. """
  582. try:
  583. x = torch.load(f, map_location=torch.device("cpu"))
  584. assert isinstance(x, dict), "checkpoint is not a Python dictionary"
  585. assert "model" in x, "'model' missing from checkpoint"
  586. except Exception as e:
  587. LOGGER.warning(f"Skipping {f}, not a valid Ultralytics model: {e}")
  588. return {}
  589. metadata = {
  590. "date": datetime.now().isoformat(),
  591. "version": __version__,
  592. "license": "AGPL-3.0 License (https://ultralytics.com/license)",
  593. "docs": "https://docs.ultralytics.com",
  594. }
  595. # Update model
  596. if x.get("ema"):
  597. x["model"] = x["ema"] # replace model with EMA
  598. if hasattr(x["model"], "args"):
  599. x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
  600. if hasattr(x["model"], "criterion"):
  601. x["model"].criterion = None # strip loss criterion
  602. x["model"].half() # to FP16
  603. for p in x["model"].parameters():
  604. p.requires_grad = False
  605. # Update other keys
  606. args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
  607. for k in "optimizer", "best_fitness", "ema", "updates": # keys
  608. x[k] = None
  609. x["epoch"] = -1
  610. x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
  611. # x['model'].args = x['train_args']
  612. # Save
  613. combined = {**metadata, **x, **(updates or {})}
  614. torch.save(combined, s or f) # combine dicts (prefer to the right)
  615. mb = os.path.getsize(s or f) / 1e6 # file size
  616. LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
  617. return combined
  618. def convert_optimizer_state_dict_to_fp16(state_dict):
  619. """
  620. Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
  621. Args:
  622. state_dict (dict): Optimizer state dictionary.
  623. Returns:
  624. (dict): Converted optimizer state dictionary with FP16 tensors.
  625. """
  626. for state in state_dict["state"].values():
  627. for k, v in state.items():
  628. if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
  629. state[k] = v.half()
  630. return state_dict
  631. @contextmanager
  632. def cuda_memory_usage(device=None):
  633. """
  634. Monitor and manage CUDA memory usage.
  635. This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
  636. It then yields a dictionary containing memory usage information, which can be updated by the caller.
  637. Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.
  638. Args:
  639. device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.
  640. Yields:
  641. (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
  642. """
  643. cuda_info = dict(memory=0)
  644. if torch.cuda.is_available():
  645. torch.cuda.empty_cache()
  646. try:
  647. yield cuda_info
  648. finally:
  649. cuda_info["memory"] = torch.cuda.memory_reserved(device)
  650. else:
  651. yield cuda_info
  652. def profile(input, ops, n=10, device=None, max_num_obj=0):
  653. """
  654. Ultralytics speed, memory and FLOPs profiler.
  655. Args:
  656. input (torch.Tensor | List[torch.Tensor]): Input tensor(s) to profile.
  657. ops (nn.Module | List[nn.Module]): Model or list of operations to profile.
  658. n (int, optional): Number of iterations to average. Defaults to 10.
  659. device (str | torch.device, optional): Device to profile on. Defaults to None.
  660. max_num_obj (int, optional): Maximum number of objects for simulation. Defaults to 0.
  661. Returns:
  662. (list): Profile results for each operation.
  663. Examples:
  664. >>> from ultralytics.utils.torch_utils import profile
  665. >>> input = torch.randn(16, 3, 640, 640)
  666. >>> m1 = lambda x: x * torch.sigmoid(x)
  667. >>> m2 = nn.SiLU()
  668. >>> profile(input, [m1, m2], n=100) # profile over 100 iterations
  669. """
  670. results = []
  671. if not isinstance(device, torch.device):
  672. device = select_device(device)
  673. LOGGER.info(
  674. f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
  675. f"{'input':>24s}{'output':>24s}"
  676. )
  677. gc.collect() # attempt to free unused memory
  678. torch.cuda.empty_cache()
  679. for x in input if isinstance(input, list) else [input]:
  680. x = x.to(device)
  681. x.requires_grad = True
  682. for m in ops if isinstance(ops, list) else [ops]:
  683. m = m.to(device) if hasattr(m, "to") else m # device
  684. m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
  685. tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
  686. try:
  687. flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
  688. except Exception:
  689. flops = 0
  690. try:
  691. mem = 0
  692. for _ in range(n):
  693. with cuda_memory_usage(device) as cuda_info:
  694. t[0] = time_sync()
  695. y = m(x)
  696. t[1] = time_sync()
  697. try:
  698. (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
  699. t[2] = time_sync()
  700. except Exception: # no backward method
  701. # print(e) # for debug
  702. t[2] = float("nan")
  703. mem += cuda_info["memory"] / 1e9 # (GB)
  704. tf += (t[1] - t[0]) * 1000 / n # ms per op forward
  705. tb += (t[2] - t[1]) * 1000 / n # ms per op backward
  706. if max_num_obj: # simulate training with predictions per image grid (for AutoBatch)
  707. with cuda_memory_usage(device) as cuda_info:
  708. torch.randn(
  709. x.shape[0],
  710. max_num_obj,
  711. int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
  712. device=device,
  713. dtype=torch.float32,
  714. )
  715. mem += cuda_info["memory"] / 1e9 # (GB)
  716. s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
  717. p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
  718. LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
  719. results.append([p, flops, mem, tf, tb, s_in, s_out])
  720. except Exception as e:
  721. LOGGER.info(e)
  722. results.append(None)
  723. finally:
  724. gc.collect() # attempt to free unused memory
  725. torch.cuda.empty_cache()
  726. return results
  727. class EarlyStopping:
  728. """
  729. Early stopping class that stops training when a specified number of epochs have passed without improvement.
  730. Attributes:
  731. best_fitness (float): Best fitness value observed.
  732. best_epoch (int): Epoch where best fitness was observed.
  733. patience (int): Number of epochs to wait after fitness stops improving before stopping.
  734. possible_stop (bool): Flag indicating if stopping may occur next epoch.
  735. """
  736. def __init__(self, patience=50):
  737. """
  738. Initialize early stopping object.
  739. Args:
  740. patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
  741. """
  742. self.best_fitness = 0.0 # i.e. mAP
  743. self.best_epoch = 0
  744. self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop
  745. self.possible_stop = False # possible stop may occur next epoch
  746. def __call__(self, epoch, fitness):
  747. """
  748. Check whether to stop training.
  749. Args:
  750. epoch (int): Current epoch of training
  751. fitness (float): Fitness value of current epoch
  752. Returns:
  753. (bool): True if training should stop, False otherwise
  754. """
  755. if fitness is None: # check if fitness=None (happens when val=False)
  756. return False
  757. if fitness > self.best_fitness or self.best_fitness == 0: # allow for early zero-fitness stage of training
  758. self.best_epoch = epoch
  759. self.best_fitness = fitness
  760. delta = epoch - self.best_epoch # epochs without improvement
  761. self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
  762. stop = delta >= self.patience # stop training if patience exceeded
  763. if stop:
  764. prefix = colorstr("EarlyStopping: ")
  765. LOGGER.info(
  766. f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
  767. f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
  768. f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
  769. f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
  770. )
  771. return stop
  772. class FXModel(nn.Module):
  773. """
  774. A custom model class for torch.fx compatibility.
  775. This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
  776. manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
  777. copying.
  778. Attributes:
  779. model (nn.Module): The original model's layers.
  780. """
  781. def __init__(self, model):
  782. """
  783. Initialize the FXModel.
  784. Args:
  785. model (nn.Module): The original model to wrap for torch.fx compatibility.
  786. """
  787. super().__init__()
  788. copy_attr(self, model)
  789. # Explicitly set `model` since `copy_attr` somehow does not copy it.
  790. self.model = model.model
  791. def forward(self, x):
  792. """
  793. Forward pass through the model.
  794. This method performs the forward pass through the model, handling the dependencies between layers and saving
  795. intermediate outputs.
  796. Args:
  797. x (torch.Tensor): The input tensor to the model.
  798. Returns:
  799. (torch.Tensor): The output tensor from the model.
  800. """
  801. y = [] # outputs
  802. for m in self.model:
  803. if m.f != -1: # if not from previous layer
  804. # from earlier layers
  805. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
  806. x = m(x) # run
  807. y.append(x) # save output
  808. return x
Discard
Tip!

Press p or to see the previous file or, n or to see the next file