Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_solutions.py 4.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import cv2
  3. import pytest
  4. from tests import TMP
  5. from ultralytics import YOLO, solutions
  6. from ultralytics.utils import ASSETS_URL, WEIGHTS_DIR
  7. from ultralytics.utils.downloads import safe_download
  8. DEMO_VIDEO = "solutions_ci_demo.mp4"
  9. POSE_VIDEO = "solution_ci_pose_demo.mp4"
  10. @pytest.mark.slow
  11. def test_major_solutions():
  12. """Test the object counting, heatmap, speed estimation, trackzone and queue management solution."""
  13. safe_download(url=f"{ASSETS_URL}/{DEMO_VIDEO}", dir=TMP)
  14. cap = cv2.VideoCapture(str(TMP / DEMO_VIDEO))
  15. assert cap.isOpened(), "Error reading video file"
  16. region_points = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
  17. counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False) # Test object counter
  18. heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False) # Test heatmaps
  19. heatmap_count = solutions.Heatmap(
  20. colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False, region=region_points
  21. ) # Test heatmaps with object counting
  22. speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False) # Test queue manager
  23. queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False) # Test speed estimation
  24. line_analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False) # line analytics
  25. pie_analytics = solutions.Analytics(analytics_type="pie", model="yolo11n.pt", show=False) # line analytics
  26. bar_analytics = solutions.Analytics(analytics_type="bar", model="yolo11n.pt", show=False) # line analytics
  27. area_analytics = solutions.Analytics(analytics_type="area", model="yolo11n.pt", show=False) # line analytics
  28. trackzone = solutions.TrackZone(region=region_points, model="yolo11n.pt", show=False) # Test trackzone
  29. frame_count = 0 # Required for analytics
  30. while cap.isOpened():
  31. success, im0 = cap.read()
  32. if not success:
  33. break
  34. frame_count += 1
  35. original_im0 = im0.copy()
  36. _ = counter.count(original_im0.copy())
  37. _ = heatmap.generate_heatmap(original_im0.copy())
  38. _ = heatmap_count.generate_heatmap(original_im0.copy())
  39. _ = speed.estimate_speed(original_im0.copy())
  40. _ = queue.process_queue(original_im0.copy())
  41. _ = line_analytics.process_data(original_im0.copy(), frame_count)
  42. _ = pie_analytics.process_data(original_im0.copy(), frame_count)
  43. _ = bar_analytics.process_data(original_im0.copy(), frame_count)
  44. _ = area_analytics.process_data(original_im0.copy(), frame_count)
  45. _ = trackzone.trackzone(original_im0.copy())
  46. cap.release()
  47. # Test workouts monitoring
  48. safe_download(url=f"{ASSETS_URL}/{POSE_VIDEO}", dir=TMP)
  49. cap = cv2.VideoCapture(str(TMP / POSE_VIDEO))
  50. assert cap.isOpened(), "Error reading video file"
  51. gym = solutions.AIGym(kpts=[5, 11, 13], show=False)
  52. while cap.isOpened():
  53. success, im0 = cap.read()
  54. if not success:
  55. break
  56. _ = gym.monitor(im0)
  57. cap.release()
  58. @pytest.mark.slow
  59. def test_instance_segmentation():
  60. """Test the instance segmentation solution."""
  61. from ultralytics.utils.plotting import Annotator, colors
  62. model = YOLO(WEIGHTS_DIR / "yolo11n-seg.pt")
  63. names = model.names
  64. cap = cv2.VideoCapture(TMP / DEMO_VIDEO)
  65. assert cap.isOpened(), "Error reading video file"
  66. while cap.isOpened():
  67. success, im0 = cap.read()
  68. if not success:
  69. break
  70. results = model.predict(im0)
  71. annotator = Annotator(im0, line_width=2)
  72. if results[0].masks is not None:
  73. clss = results[0].boxes.cls.cpu().tolist()
  74. masks = results[0].masks.xy
  75. for mask, cls in zip(masks, clss):
  76. color = colors(int(cls), True)
  77. annotator.seg_bbox(mask=mask, mask_color=color, label=names[int(cls)])
  78. cap.release()
  79. cv2.destroyAllWindows()
  80. @pytest.mark.slow
  81. def test_streamlit_predict():
  82. """Test streamlit predict live inference solution."""
  83. solutions.Inference().inference()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...