Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolov8_sahi.py 3.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import argparse
  3. from pathlib import Path
  4. import cv2
  5. from sahi import AutoDetectionModel
  6. from sahi.predict import get_sliced_prediction
  7. from sahi.utils.ultralytics import download_yolo11n_model
  8. from ultralytics.utils.files import increment_path
  9. from ultralytics.utils.plotting import Annotator, colors
  10. class SAHIInference:
  11. """Runs Ultralytics YOLO11 and SAHI for object detection on video with options to view, save, and track results."""
  12. def __init__(self):
  13. """Initializes the SAHIInference class for performing sliced inference using SAHI with YOLO11 models."""
  14. self.detection_model = None
  15. def load_model(self, weights):
  16. """Loads a YOLO11 model with specified weights for object detection using SAHI."""
  17. yolo11_model_path = f"models/{weights}"
  18. download_yolo11n_model(yolo11_model_path)
  19. self.detection_model = AutoDetectionModel.from_pretrained(
  20. model_type="ultralytics", model_path=yolo11_model_path, device="cpu"
  21. )
  22. def inference(
  23. self,
  24. weights="yolo11n.pt",
  25. source="test.mp4",
  26. view_img=False,
  27. save_img=False,
  28. exist_ok=False,
  29. ):
  30. """
  31. Run object detection on a video using YOLO11 and SAHI.
  32. Args:
  33. weights (str): Model weights path.
  34. source (str): Video file path.
  35. view_img (bool): Show results.
  36. save_img (bool): Save results.
  37. exist_ok (bool): Overwrite existing files.
  38. """
  39. # Video setup
  40. cap = cv2.VideoCapture(source)
  41. assert cap.isOpened(), "Error reading video file"
  42. frame_width, frame_height = int(cap.get(3)), int(cap.get(4))
  43. # Output setup
  44. save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok)
  45. save_dir.mkdir(parents=True, exist_ok=True)
  46. video_writer = cv2.VideoWriter(
  47. str(save_dir / f"{Path(source).stem}.avi"),
  48. cv2.VideoWriter_fourcc(*"MJPG"),
  49. int(cap.get(5)),
  50. (frame_width, frame_height),
  51. )
  52. # Load model
  53. self.load_model(weights)
  54. while cap.isOpened():
  55. success, frame = cap.read()
  56. if not success:
  57. break
  58. annotator = Annotator(frame) # Initialize annotator for plotting detection and tracking results
  59. results = get_sliced_prediction(
  60. frame[..., ::-1],
  61. self.detection_model,
  62. slice_height=512,
  63. slice_width=512,
  64. )
  65. detection_data = [
  66. (det.category.name, det.category.id, (det.bbox.minx, det.bbox.miny, det.bbox.maxx, det.bbox.maxy))
  67. for det in results.object_prediction_list
  68. ]
  69. for det in detection_data:
  70. annotator.box_label(det[2], label=str(det[0]), color=colors(int(det[1]), True))
  71. if view_img:
  72. cv2.imshow(Path(source).stem, frame)
  73. if save_img:
  74. video_writer.write(frame)
  75. if cv2.waitKey(1) & 0xFF == ord("q"):
  76. break
  77. video_writer.release()
  78. cap.release()
  79. cv2.destroyAllWindows()
  80. def parse_opt(self):
  81. """Parse command line arguments."""
  82. parser = argparse.ArgumentParser()
  83. parser.add_argument("--weights", type=str, default="yolo11n.pt", help="initial weights path")
  84. parser.add_argument("--source", type=str, required=True, help="video file path")
  85. parser.add_argument("--view-img", action="store_true", help="show results")
  86. parser.add_argument("--save-img", action="store_true", help="save results")
  87. parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
  88. return parser.parse_args()
  89. if __name__ == "__main__":
  90. inference = SAHIInference()
  91. inference.inference(**vars(inference.parse_opt()))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...