yolov8_sahi.py

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import argparse

import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.ultralytics import download_model_weights

from ultralytics.utils.files import increment_path


class SAHIInference:
    """
    Runs Ultralytics YOLO11 and SAHI for object detection on video with options to view, save, and track results.

    This class integrates SAHI (Slicing Aided Hyper Inference) with YOLO11 models to perform efficient object
    detection on large images by slicing them into smaller pieces, running inference on each slice, and then
    merging the results.

    Attributes:
        detection_model (AutoDetectionModel): The loaded YOLO11 model wrapped with SAHI functionality.

    Methods:
        load_model: Load a YOLO11 model with specified weights for object detection using SAHI.
        inference: Run object detection on a video using YOLO11 and SAHI.
        parse_opt: Parse command line arguments for the inference process.

    Examples:
        Initialize and run SAHI inference on a video
        >>> sahi_inference = SAHIInference()
        >>> sahi_inference.inference(weights="yolo11n.pt", source="video.mp4", view_img=True)
    """

    def __init__(self):
        """Initialize the SAHIInference class for performing sliced inference using SAHI with YOLO11 models."""
        self.detection_model = None

    def load_model(self, weights: str, device: str) -> None:
        """
        Load a YOLO11 model with specified weights for object detection using SAHI.

        Args:
            weights (str): Path to the model weights file.
            device (str): CUDA device, i.e., '0' or '0,1,2,3' or 'cpu'.
        """
        from ultralytics.utils.torch_utils import select_device

        yolo11_model_path = f"models/{weights}"
        download_model_weights(yolo11_model_path)  # Download model if not present
        self.detection_model = AutoDetectionModel.from_pretrained(
            model_type="ultralytics", model_path=yolo11_model_path, device=select_device(device)
        )

    def inference(
        self,
        weights: str = "yolo11n.pt",
        source: str = "test.mp4",
        view_img: bool = False,
        save_img: bool = False,
        exist_ok: bool = False,
        device: str = "",
        hide_conf: bool = False,
        slice_width: int = 512,
        slice_height: int = 512,
    ) -> None:
        """
        Run object detection on a video using YOLO11 and SAHI.

        The function processes each frame of the video, applies sliced inference using SAHI, and optionally
        displays and/or saves the results with bounding boxes and labels.

        Args:
            weights (str): Model weights' path.
            source (str): Video file path.
            view_img (bool): Whether to display results in a window.
            save_img (bool): Whether to save results to a video file.
            exist_ok (bool): Whether to overwrite existing output files.
            device (str, optional): CUDA device, i.e., '0' or '0,1,2,3' or 'cpu'.
            hide_conf (bool, optional): Flag to show or hide confidences in the output.
            slice_width (int, optional): Slice width for inference.
            slice_height (int, optional): Slice height for inference.
        """
        # Video setup
        cap = cv2.VideoCapture(source)
        assert cap.isOpened(), "Error reading video file"

        # Output setup
        save_dir = increment_path("runs/detect/predict", exist_ok)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Load model
        self.load_model(weights, device)

        idx = 0  # Index for image frame writing
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            # Perform sliced prediction using SAHI
            results = get_sliced_prediction(
                frame[..., ::-1],  # Convert BGR to RGB
                self.detection_model,
                slice_height=slice_height,
                slice_width=slice_width,
            )

            # Display results if requested
            if view_img:
                cv2.imshow("Ultralytics YOLO Inference", frame)

            # Save results if requested
            if save_img:
                idx += 1
                results.export_visuals(export_dir=save_dir, file_name=f"img_{idx}", hide_conf=hide_conf)

            # Break loop if 'q' is pressed
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break

        # Clean up resources
        cap.release()
        cv2.destroyAllWindows()

    @staticmethod
    def parse_opt() -> argparse.Namespace:
        """
        Parse command line arguments for the inference process.

        Returns:
            (argparse.Namespace): Parsed command line arguments.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--weights", type=str, default="yolo11n.pt", help="initial weights path")
        parser.add_argument("--source", type=str, required=True, help="video file path")
        parser.add_argument("--view-img", action="store_true", help="show results")
        parser.add_argument("--save-img", action="store_true", help="save results")
        parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
        parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
        parser.add_argument("--hide-conf", default=False, action="store_true", help="display or hide confidences")
        parser.add_argument("--slice-width", default=512, type=int, help="Slice width for inference")
        parser.add_argument("--slice-height", default=512, type=int, help="Slice height for inference")
        return parser.parse_args()


if __name__ == "__main__":
    inference = SAHIInference()
    inference.inference(**vars(inference.parse_opt()))
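
A quick usage sketch, not part of the file above: it assumes the script is saved as yolov8_sahi.py in the working directory and that a video named video.mp4 exists there (both are placeholders). The command-line flags mirror parse_opt, and the programmatic call mirrors the docstring example.

    # Command-line invocation (--source is required; the other flags are optional)
    #   python yolov8_sahi.py --source video.mp4 --weights yolo11n.pt --view-img --save-img

    # Programmatic use from another script in the same directory
    from yolov8_sahi import SAHIInference

    sahi_inference = SAHIInference()
    sahi_inference.inference(weights="yolo11n.pt", source="video.mp4", view_img=True)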