Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_engine.py 4.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import sys
  3. from unittest import mock
  4. from tests import MODEL
  5. from ultralytics import YOLO
  6. from ultralytics.cfg import get_cfg
  7. from ultralytics.engine.exporter import Exporter
  8. from ultralytics.models.yolo import classify, detect, segment
  9. from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR
  10. def test_func(*args): # noqa
  11. """Test function callback for evaluating YOLO model performance metrics."""
  12. print("callback test passed")
  13. def test_export():
  14. """Tests the model exporting function by adding a callback and asserting its execution."""
  15. exporter = Exporter()
  16. exporter.add_callback("on_export_start", test_func)
  17. assert test_func in exporter.callbacks["on_export_start"], "callback test failed"
  18. f = exporter(model=YOLO("yolo11n.yaml").model)
  19. YOLO(f)(ASSETS) # exported model inference
  20. def test_detect():
  21. """Test YOLO object detection training, validation, and prediction functionality."""
  22. overrides = {"data": "coco8.yaml", "model": "yolo11n.yaml", "imgsz": 32, "epochs": 1, "save": False}
  23. cfg = get_cfg(DEFAULT_CFG)
  24. cfg.data = "coco8.yaml"
  25. cfg.imgsz = 32
  26. # Trainer
  27. trainer = detect.DetectionTrainer(overrides=overrides)
  28. trainer.add_callback("on_train_start", test_func)
  29. assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
  30. trainer.train()
  31. # Validator
  32. val = detect.DetectionValidator(args=cfg)
  33. val.add_callback("on_val_start", test_func)
  34. assert test_func in val.callbacks["on_val_start"], "callback test failed"
  35. val(model=trainer.best) # validate best.pt
  36. # Predictor
  37. pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
  38. pred.add_callback("on_predict_start", test_func)
  39. assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
  40. # Confirm there is no issue with sys.argv being empty.
  41. with mock.patch.object(sys, "argv", []):
  42. result = pred(source=ASSETS, model=MODEL)
  43. assert len(result), "predictor test failed"
  44. overrides["resume"] = trainer.last
  45. trainer = detect.DetectionTrainer(overrides=overrides)
  46. try:
  47. trainer.train()
  48. except Exception as e:
  49. print(f"Expected exception caught: {e}")
  50. return
  51. Exception("Resume test failed!")
  52. def test_segment():
  53. """Tests image segmentation training, validation, and prediction pipelines using YOLO models."""
  54. overrides = {"data": "coco8-seg.yaml", "model": "yolo11n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False}
  55. cfg = get_cfg(DEFAULT_CFG)
  56. cfg.data = "coco8-seg.yaml"
  57. cfg.imgsz = 32
  58. # YOLO(CFG_SEG).train(**overrides) # works
  59. # Trainer
  60. trainer = segment.SegmentationTrainer(overrides=overrides)
  61. trainer.add_callback("on_train_start", test_func)
  62. assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
  63. trainer.train()
  64. # Validator
  65. val = segment.SegmentationValidator(args=cfg)
  66. val.add_callback("on_val_start", test_func)
  67. assert test_func in val.callbacks["on_val_start"], "callback test failed"
  68. val(model=trainer.best) # validate best.pt
  69. # Predictor
  70. pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
  71. pred.add_callback("on_predict_start", test_func)
  72. assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
  73. result = pred(source=ASSETS, model=WEIGHTS_DIR / "yolo11n-seg.pt")
  74. assert len(result), "predictor test failed"
  75. # Test resume
  76. overrides["resume"] = trainer.last
  77. trainer = segment.SegmentationTrainer(overrides=overrides)
  78. try:
  79. trainer.train()
  80. except Exception as e:
  81. print(f"Expected exception caught: {e}")
  82. return
  83. Exception("Resume test failed!")
  84. def test_classify():
  85. """Test image classification including training, validation, and prediction phases."""
  86. overrides = {"data": "imagenet10", "model": "yolo11n-cls.yaml", "imgsz": 32, "epochs": 1, "save": False}
  87. cfg = get_cfg(DEFAULT_CFG)
  88. cfg.data = "imagenet10"
  89. cfg.imgsz = 32
  90. # YOLO(CFG_SEG).train(**overrides) # works
  91. # Trainer
  92. trainer = classify.ClassificationTrainer(overrides=overrides)
  93. trainer.add_callback("on_train_start", test_func)
  94. assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
  95. trainer.train()
  96. # Validator
  97. val = classify.ClassificationValidator(args=cfg)
  98. val.add_callback("on_val_start", test_func)
  99. assert test_func in val.callbacks["on_val_start"], "callback test failed"
  100. val(model=trainer.best)
  101. # Predictor
  102. pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
  103. pred.add_callback("on_predict_start", test_func)
  104. assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
  105. result = pred(source=ASSETS, model=trainer.best)
  106. assert len(result), "predictor test failed"
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...