Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_engine.py 4.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import sys
  3. from unittest import mock
  4. from tests import MODEL
  5. from ultralytics import YOLO
  6. from ultralytics.cfg import get_cfg
  7. from ultralytics.engine.exporter import Exporter
  8. from ultralytics.models.yolo import classify, detect, segment
  9. from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR
  10. def test_func(*args): # noqa
  11. """Test function callback for evaluating YOLO model performance metrics."""
  12. print("callback test passed")
  13. def test_export():
  14. """Test model exporting functionality by adding a callback and verifying its execution."""
  15. exporter = Exporter()
  16. exporter.add_callback("on_export_start", test_func)
  17. assert test_func in exporter.callbacks["on_export_start"], "callback test failed"
  18. f = exporter(model=YOLO("yolo11n.yaml").model)
  19. YOLO(f)(ASSETS) # exported model inference
  20. def test_detect():
  21. """Test YOLO object detection training, validation, and prediction functionality."""
  22. overrides = {"data": "coco8.yaml", "model": "yolo11n.yaml", "imgsz": 32, "epochs": 1, "save": False}
  23. cfg = get_cfg(DEFAULT_CFG)
  24. cfg.data = "coco8.yaml"
  25. cfg.imgsz = 32
  26. # Trainer
  27. trainer = detect.DetectionTrainer(overrides=overrides)
  28. trainer.add_callback("on_train_start", test_func)
  29. assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
  30. trainer.train()
  31. # Validator
  32. val = detect.DetectionValidator(args=cfg)
  33. val.add_callback("on_val_start", test_func)
  34. assert test_func in val.callbacks["on_val_start"], "callback test failed"
  35. val(model=trainer.best) # validate best.pt
  36. # Predictor
  37. pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
  38. pred.add_callback("on_predict_start", test_func)
  39. assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
  40. # Confirm there is no issue with sys.argv being empty
  41. with mock.patch.object(sys, "argv", []):
  42. result = pred(source=ASSETS, model=MODEL)
  43. assert len(result), "predictor test failed"
  44. # Test resume functionality
  45. overrides["resume"] = trainer.last
  46. trainer = detect.DetectionTrainer(overrides=overrides)
  47. try:
  48. trainer.train()
  49. except Exception as e:
  50. print(f"Expected exception caught: {e}")
  51. return
  52. raise Exception("Resume test failed!")
  53. def test_segment():
  54. """Test image segmentation training, validation, and prediction pipelines using YOLO models."""
  55. overrides = {"data": "coco8-seg.yaml", "model": "yolo11n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False}
  56. cfg = get_cfg(DEFAULT_CFG)
  57. cfg.data = "coco8-seg.yaml"
  58. cfg.imgsz = 32
  59. # Trainer
  60. trainer = segment.SegmentationTrainer(overrides=overrides)
  61. trainer.add_callback("on_train_start", test_func)
  62. assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
  63. trainer.train()
  64. # Validator
  65. val = segment.SegmentationValidator(args=cfg)
  66. val.add_callback("on_val_start", test_func)
  67. assert test_func in val.callbacks["on_val_start"], "callback test failed"
  68. val(model=trainer.best) # validate best.pt
  69. # Predictor
  70. pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
  71. pred.add_callback("on_predict_start", test_func)
  72. assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
  73. result = pred(source=ASSETS, model=WEIGHTS_DIR / "yolo11n-seg.pt")
  74. assert len(result), "predictor test failed"
  75. # Test resume functionality
  76. overrides["resume"] = trainer.last
  77. trainer = segment.SegmentationTrainer(overrides=overrides)
  78. try:
  79. trainer.train()
  80. except Exception as e:
  81. print(f"Expected exception caught: {e}")
  82. return
  83. raise Exception("Resume test failed!")
  84. def test_classify():
  85. """Test image classification including training, validation, and prediction phases."""
  86. overrides = {"data": "imagenet10", "model": "yolo11n-cls.yaml", "imgsz": 32, "epochs": 1, "save": False}
  87. cfg = get_cfg(DEFAULT_CFG)
  88. cfg.data = "imagenet10"
  89. cfg.imgsz = 32
  90. # Trainer
  91. trainer = classify.ClassificationTrainer(overrides=overrides)
  92. trainer.add_callback("on_train_start", test_func)
  93. assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
  94. trainer.train()
  95. # Validator
  96. val = classify.ClassificationValidator(args=cfg)
  97. val.add_callback("on_val_start", test_func)
  98. assert test_func in val.callbacks["on_val_start"], "callback test failed"
  99. val(model=trainer.best)
  100. # Predictor
  101. pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
  102. pred.add_callback("on_predict_start", test_func)
  103. assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
  104. result = pred(source=ASSETS, model=trainer.best)
  105. assert len(result), "predictor test failed"
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...