Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolov8_sahi.py 4.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
  1. import argparse
  2. from pathlib import Path
  3. import cv2
  4. from sahi import AutoDetectionModel
  5. from sahi.predict import get_sliced_prediction
  6. from sahi.utils.yolov8 import download_yolov8s_model
  7. from ultralytics.utils.files import increment_path
  8. def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False):
  9. """
  10. Run object detection on a video using YOLOv8 and SAHI.
  11. Args:
  12. weights (str): Model weights path.
  13. source (str): Video file path.
  14. view_img (bool): Show results.
  15. save_img (bool): Save results.
  16. exist_ok (bool): Overwrite existing files.
  17. """
  18. # Check source path
  19. if not Path(source).exists():
  20. raise FileNotFoundError(f"Source path '{source}' does not exist.")
  21. yolov8_model_path = f'models/{weights}'
  22. download_yolov8s_model(yolov8_model_path)
  23. detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8',
  24. model_path=yolov8_model_path,
  25. confidence_threshold=0.3,
  26. device='cpu')
  27. # Video setup
  28. videocapture = cv2.VideoCapture(source)
  29. frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
  30. fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v')
  31. # Output setup
  32. save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok)
  33. save_dir.mkdir(parents=True, exist_ok=True)
  34. video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height))
  35. while videocapture.isOpened():
  36. success, frame = videocapture.read()
  37. if not success:
  38. break
  39. results = get_sliced_prediction(frame,
  40. detection_model,
  41. slice_height=512,
  42. slice_width=512,
  43. overlap_height_ratio=0.2,
  44. overlap_width_ratio=0.2)
  45. object_prediction_list = results.object_prediction_list
  46. boxes_list = []
  47. clss_list = []
  48. for ind, _ in enumerate(object_prediction_list):
  49. boxes = object_prediction_list[ind].bbox.minx, object_prediction_list[ind].bbox.miny, \
  50. object_prediction_list[ind].bbox.maxx, object_prediction_list[ind].bbox.maxy
  51. clss = object_prediction_list[ind].category.name
  52. boxes_list.append(boxes)
  53. clss_list.append(clss)
  54. for box, cls in zip(boxes_list, clss_list):
  55. x1, y1, x2, y2 = box
  56. cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2)
  57. label = str(cls)
  58. t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
  59. cv2.rectangle(frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255),
  60. -1)
  61. cv2.putText(frame,
  62. label, (int(x1), int(y1) - 2),
  63. 0,
  64. 0.6, [255, 255, 255],
  65. thickness=1,
  66. lineType=cv2.LINE_AA)
  67. if view_img:
  68. cv2.imshow(Path(source).stem, frame)
  69. if save_img:
  70. video_writer.write(frame)
  71. if cv2.waitKey(1) & 0xFF == ord('q'):
  72. break
  73. video_writer.release()
  74. videocapture.release()
  75. cv2.destroyAllWindows()
  76. def parse_opt():
  77. """Parse command line arguments."""
  78. parser = argparse.ArgumentParser()
  79. parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
  80. parser.add_argument('--source', type=str, required=True, help='video file path')
  81. parser.add_argument('--view-img', action='store_true', help='show results')
  82. parser.add_argument('--save-img', action='store_true', help='save results')
  83. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  84. return parser.parse_args()
  85. def main(opt):
  86. """Main function."""
  87. run(**vars(opt))
  88. if __name__ == '__main__':
  89. opt = parse_opt()
  90. main(opt)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...