Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_solutions.py 5.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import cv2
  3. import pytest
  4. from tests import TMP
  5. from ultralytics import YOLO, solutions
  6. from ultralytics.utils import ASSETS_URL, WEIGHTS_DIR
  7. from ultralytics.utils.downloads import safe_download
  8. DEMO_VIDEO = "solutions_ci_demo.mp4" # for all the solutions, except workout and parking
  9. POSE_VIDEO = "solution_ci_pose_demo.mp4" # only for workouts monitoring solution
  10. PARKING_VIDEO = "solution_ci_parking_demo.mp4" # only for parking management solution
  11. PARKING_AREAS_JSON = "solution_ci_parking_areas.json" # only for parking management solution
  12. PARKING_MODEL = "solutions_ci_parking_model.pt" # only for parking management solution
  13. @pytest.mark.slow
  14. def test_major_solutions():
  15. """Test the object counting, heatmap, speed estimation, trackzone and queue management solution."""
  16. safe_download(url=f"{ASSETS_URL}/{DEMO_VIDEO}", dir=TMP)
  17. cap = cv2.VideoCapture(str(TMP / DEMO_VIDEO))
  18. assert cap.isOpened(), "Error reading video file"
  19. region_points = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
  20. counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False) # Test object counter
  21. heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False) # Test heatmaps
  22. heatmap_count = solutions.Heatmap(
  23. colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False, region=region_points
  24. ) # Test heatmaps with object counting
  25. speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False) # Test queue manager
  26. queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False) # Test speed estimation
  27. line_analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False) # line analytics
  28. pie_analytics = solutions.Analytics(analytics_type="pie", model="yolo11n.pt", show=False) # line analytics
  29. bar_analytics = solutions.Analytics(analytics_type="bar", model="yolo11n.pt", show=False) # line analytics
  30. area_analytics = solutions.Analytics(analytics_type="area", model="yolo11n.pt", show=False) # line analytics
  31. trackzone = solutions.TrackZone(region=region_points, model="yolo11n.pt", show=False) # Test trackzone
  32. frame_count = 0 # Required for analytics
  33. while cap.isOpened():
  34. success, im0 = cap.read()
  35. if not success:
  36. break
  37. frame_count += 1
  38. original_im0 = im0.copy()
  39. _ = counter.count(original_im0.copy())
  40. _ = heatmap.generate_heatmap(original_im0.copy())
  41. _ = heatmap_count.generate_heatmap(original_im0.copy())
  42. _ = speed.estimate_speed(original_im0.copy())
  43. _ = queue.process_queue(original_im0.copy())
  44. _ = line_analytics.process_data(original_im0.copy(), frame_count)
  45. _ = pie_analytics.process_data(original_im0.copy(), frame_count)
  46. _ = bar_analytics.process_data(original_im0.copy(), frame_count)
  47. _ = area_analytics.process_data(original_im0.copy(), frame_count)
  48. _ = trackzone.trackzone(original_im0.copy())
  49. cap.release()
  50. # Test workouts monitoring
  51. safe_download(url=f"{ASSETS_URL}/{POSE_VIDEO}", dir=TMP)
  52. cap = cv2.VideoCapture(str(TMP / POSE_VIDEO))
  53. assert cap.isOpened(), "Error reading video file"
  54. gym = solutions.AIGym(kpts=[5, 11, 13], show=False)
  55. while cap.isOpened():
  56. success, im0 = cap.read()
  57. if not success:
  58. break
  59. _ = gym.monitor(im0)
  60. cap.release()
  61. # Test parking management
  62. safe_download(url=f"{ASSETS_URL}/{PARKING_VIDEO}", dir=TMP)
  63. safe_download(url=f"{ASSETS_URL}/{PARKING_AREAS_JSON}", dir=TMP)
  64. safe_download(url=f"{ASSETS_URL}/{PARKING_MODEL}", dir=TMP)
  65. cap = cv2.VideoCapture(str(TMP / PARKING_VIDEO))
  66. assert cap.isOpened(), "Error reading video file"
  67. parkingmanager = solutions.ParkingManagement(
  68. json_file=str(TMP / PARKING_AREAS_JSON), model=str(TMP / PARKING_MODEL), show=False
  69. )
  70. while cap.isOpened():
  71. success, im0 = cap.read()
  72. if not success:
  73. break
  74. _ = parkingmanager.process_data(im0)
  75. cap.release()
  76. @pytest.mark.slow
  77. def test_instance_segmentation():
  78. """Test the instance segmentation solution."""
  79. from ultralytics.utils.plotting import Annotator, colors
  80. model = YOLO(WEIGHTS_DIR / "yolo11n-seg.pt")
  81. names = model.names
  82. cap = cv2.VideoCapture(TMP / DEMO_VIDEO)
  83. assert cap.isOpened(), "Error reading video file"
  84. while cap.isOpened():
  85. success, im0 = cap.read()
  86. if not success:
  87. break
  88. results = model.predict(im0)
  89. annotator = Annotator(im0, line_width=2)
  90. if results[0].masks is not None:
  91. clss = results[0].boxes.cls.cpu().tolist()
  92. masks = results[0].masks.xy
  93. for mask, cls in zip(masks, clss):
  94. color = colors(int(cls), True)
  95. annotator.seg_bbox(mask=mask, mask_color=color, label=names[int(cls)])
  96. cap.release()
  97. cv2.destroyAllWindows()
  98. @pytest.mark.slow
  99. def test_streamlit_predict():
  100. """Test streamlit predict live inference solution."""
  101. solutions.Inference().inference()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...