test_integrations.py 6.0 KB

# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import os
import subprocess
import time
from pathlib import Path

import pytest

from tests import MODEL, SOURCE, TMP
from ultralytics import YOLO, download
from ultralytics.utils import DATASETS_DIR, SETTINGS
from ultralytics.utils.checks import check_requirements


@pytest.mark.skipif(not check_requirements("ray", install=False), reason="ray[tune] not installed")
def test_model_ray_tune():
    """Tune YOLO model using Ray for hyperparameter optimization."""
    YOLO("yolo11n-cls.yaml").tune(
        use_ray=True, data="imagenet10", grace_period=1, iterations=1, imgsz=32, epochs=1, plots=False, device="cpu"
    )


@pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed")
def test_mlflow():
    """Test training with MLflow tracking enabled (see https://mlflow.org/ for details)."""
    SETTINGS["mlflow"] = True
    YOLO("yolo11n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=3, plots=False, device="cpu")
    SETTINGS["mlflow"] = False


@pytest.mark.skipif(True, reason="Test failing in scheduled CI https://github.com/ultralytics/ultralytics/pull/8868")
@pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed")
def test_mlflow_keep_run_active():
    """Ensure MLflow run status matches MLFLOW_KEEP_RUN_ACTIVE environment variable settings."""
    import mlflow

    SETTINGS["mlflow"] = True
    run_name = "Test Run"
    os.environ["MLFLOW_RUN"] = run_name

    # Test with MLFLOW_KEEP_RUN_ACTIVE=True
    os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "True"
    YOLO("yolo11n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu")
    status = mlflow.active_run().info.status
    assert status == "RUNNING", "MLflow run should be active when MLFLOW_KEEP_RUN_ACTIVE=True"

    run_id = mlflow.active_run().info.run_id

    # Test with MLFLOW_KEEP_RUN_ACTIVE=False
    os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "False"
    YOLO("yolo11n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu")
    status = mlflow.get_run(run_id=run_id).info.status
    assert status == "FINISHED", "MLflow run should be ended when MLFLOW_KEEP_RUN_ACTIVE=False"

    # Test with MLFLOW_KEEP_RUN_ACTIVE not set
    os.environ.pop("MLFLOW_KEEP_RUN_ACTIVE", None)
    YOLO("yolo11n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu")
    status = mlflow.get_run(run_id=run_id).info.status
    assert status == "FINISHED", "MLflow run should be ended by default when MLFLOW_KEEP_RUN_ACTIVE is not set"
    SETTINGS["mlflow"] = False


@pytest.mark.skipif(not check_requirements("tritonclient", install=False), reason="tritonclient[all] not installed")
def test_triton():
    """
    Test NVIDIA Triton Server functionalities with YOLO model.

    See https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver.
    """
    check_requirements("tritonclient[all]")
    from tritonclient.http import InferenceServerClient  # noqa

    # Create variables
    model_name = "yolo"
    triton_repo = TMP / "triton_repo"  # Triton repo path
    triton_model = triton_repo / model_name  # Triton model path

    # Export model to ONNX
    f = YOLO(MODEL).export(format="onnx", dynamic=True)

    # Prepare Triton repo
    (triton_model / "1").mkdir(parents=True, exist_ok=True)
    Path(f).rename(triton_model / "1" / "model.onnx")
    (triton_model / "config.pbtxt").touch()

    # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
    tag = "nvcr.io/nvidia/tritonserver:23.09-py3"  # 6.4 GB

    # Pull the image
    subprocess.call(f"docker pull {tag}", shell=True)

    # Run the Triton server and capture the container ID
    container_id = (
        subprocess.check_output(
            f"docker run -d --rm -v {triton_repo}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models",
            shell=True,
        )
        .decode("utf-8")
        .strip()
    )

    # Wait for the Triton server to start
    triton_client = InferenceServerClient(url="localhost:8000", verbose=False, ssl=False)

    # Wait until model is ready
    for _ in range(10):
        with contextlib.suppress(Exception):
            assert triton_client.is_model_ready(model_name)
            break
        time.sleep(1)

    # Check Triton inference
    YOLO(f"http://localhost:8000/{model_name}", "detect")(SOURCE)  # exported model inference

    # Kill and remove the container at the end of the test
    subprocess.call(f"docker kill {container_id}", shell=True)


@pytest.mark.skipif(not check_requirements("pycocotools", install=False), reason="pycocotools not installed")
def test_pycocotools():
    """Validate YOLO model predictions on COCO dataset using pycocotools."""
    from ultralytics.models.yolo.detect import DetectionValidator
    from ultralytics.models.yolo.pose import PoseValidator
    from ultralytics.models.yolo.segment import SegmentationValidator

    # Download annotations after each dataset downloads first
    url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/"

    args = {"model": "yolo11n.pt", "data": "coco8.yaml", "save_json": True, "imgsz": 64}
    validator = DetectionValidator(args=args)
    validator()
    validator.is_coco = True
    download(f"{url}instances_val2017.json", dir=DATASETS_DIR / "coco8/annotations")
    _ = validator.eval_json(validator.stats)

    args = {"model": "yolo11n-seg.pt", "data": "coco8-seg.yaml", "save_json": True, "imgsz": 64}
    validator = SegmentationValidator(args=args)
    validator()
    validator.is_coco = True
    download(f"{url}instances_val2017.json", dir=DATASETS_DIR / "coco8-seg/annotations")
    _ = validator.eval_json(validator.stats)

    args = {"model": "yolo11n-pose.pt", "data": "coco8-pose.yaml", "save_json": True, "imgsz": 64}
    validator = PoseValidator(args=args)
    validator()
    validator.is_coco = True
    download(f"{url}person_keypoints_val2017.json", dir=DATASETS_DIR / "coco8-pose/annotations")
    _ = validator.eval_json(validator.stats)
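The functions above are ordinary pytest tests, so they can be run individually once the relevant optional dependency is installed. A minimal sketch of a selective run, assuming the file lives at tests/test_integrations.py in the Ultralytics repository and that pytest is installed (the exact path may differ in your checkout):

import pytest

# Run only the MLflow integration test with verbose output. Tests whose optional
# requirements (ray[tune], mlflow, tritonclient[all], pycocotools) are missing
# are skipped automatically by the @pytest.mark.skipif guards above.
pytest.main(["tests/test_integrations.py::test_mlflow", "-v"])

The command-line equivalent is: pytest tests/test_integrations.py -k mlflow -v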