Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commit into Ultralytics:main from ultralytics:yoloe-vp-fix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. """Monkey patches to update/extend functionality of existing functions."""
  3. import time
  4. from pathlib import Path
  5. import cv2
  6. import numpy as np
  7. import torch
  8. # OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
  9. _imshow = cv2.imshow # reference to the original cv2.imshow, called by the imshow() wrapper in this file (copy avoids recursion errors)
  10. def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
  11. """
  12. Read an image from a file.
  13. Args:
  14. filename (str): Path to the file to read.
  15. flags (int): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
  16. Returns:
  17. (np.ndarray): The read image.
  18. Examples:
  19. >>> img = imread("path/to/image.jpg")
  20. >>> img = imread("path/to/image.jpg", cv2.IMREAD_GRAYSCALE)
  21. """
  22. file_bytes = np.fromfile(filename, np.uint8)
  23. if filename.endswith((".tiff", ".tif")):
  24. success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
  25. if success:
  26. # handle RGB images in tif/tiff format
  27. return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
  28. return None
  29. else:
  30. return cv2.imdecode(file_bytes, flags)
  31. def imwrite(filename: str, img: np.ndarray, params=None):
  32. """
  33. Write an image to a file.
  34. Args:
  35. filename (str): Path to the file to write.
  36. img (np.ndarray): Image to write.
  37. params (List[int], optional): Additional parameters for image encoding.
  38. Returns:
  39. (bool): True if the file was written successfully, False otherwise.
  40. Examples:
  41. >>> import numpy as np
  42. >>> img = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image
  43. >>> success = imwrite("output.jpg", img) # Write image to file
  44. >>> print(success)
  45. True
  46. """
  47. try:
  48. cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
  49. return True
  50. except Exception:
  51. return False
  52. def imshow(winname: str, mat: np.ndarray):
  53. """
  54. Display an image in the specified window.
  55. This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It is
  56. particularly useful for visualizing images during development and debugging.
  57. Args:
  58. winname (str): Name of the window where the image will be displayed. If a window with this name already
  59. exists, the image will be displayed in that window.
  60. mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
  61. Examples:
  62. >>> import numpy as np
  63. >>> img = np.zeros((300, 300, 3), dtype=np.uint8) # Create a black image
  64. >>> img[:100, :100] = [255, 0, 0] # Add a blue square
  65. >>> imshow("Example Window", img) # Display the image
  66. """
  67. _imshow(winname.encode("unicode_escape").decode(), mat)
  68. # PyTorch functions ----------------------------------------------------------------------------------------------------
  69. _torch_load = torch.load # reference to the original torch.load, called by torch_load() (copy avoids recursion errors)
  70. _torch_save = torch.save # reference to the original torch.save, called by torch_save()
  71. def torch_load(*args, **kwargs):
  72. """
  73. Load a PyTorch model with updated arguments to avoid warnings.
  74. This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
  75. Args:
  76. *args (Any): Variable length argument list to pass to torch.load.
  77. **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.
  78. Returns:
  79. (Any): The loaded PyTorch object.
  80. Notes:
  81. For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'
  82. if the argument is not provided, to avoid deprecation warnings.
  83. """
  84. from ultralytics.utils.torch_utils import TORCH_1_13
  85. if TORCH_1_13 and "weights_only" not in kwargs:
  86. kwargs["weights_only"] = False
  87. return _torch_load(*args, **kwargs)
  88. def torch_save(*args, **kwargs):
  89. """
  90. Save PyTorch objects with retry mechanism for robustness.
  91. This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
  92. due to device flushing delays or antivirus scanning.
  93. Args:
  94. *args (Any): Positional arguments to pass to torch.save.
  95. **kwargs (Any): Keyword arguments to pass to torch.save.
  96. Returns:
  97. (Any): Result of torch.save operation if successful, None otherwise.
  98. Examples:
  99. >>> model = torch.nn.Linear(10, 1)
  100. >>> torch_save(model.state_dict(), "model.pt")
  101. """
  102. for i in range(4): # 3 retries
  103. try:
  104. return _torch_save(*args, **kwargs)
  105. except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan
  106. if i == 3:
  107. raise e
  108. time.sleep((2**i) / 2) # exponential standoff: 0.5s, 1.0s, 2.0s
Discard
Tip!

Press p to see the previous file, or n to see the next file