Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_integrations.py 4.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import contextlib
  3. from pathlib import Path
  4. import pytest
  5. from ultralytics import YOLO, download
  6. from ultralytics.utils import ASSETS, DATASETS_DIR, ROOT, SETTINGS, WEIGHTS_DIR
  7. from ultralytics.utils.checks import check_requirements
  # Shared fixtures for the integration tests below.
  MODEL = WEIGHTS_DIR.joinpath('path with spaces', 'yolov8n.pt')  # weights under a directory whose name contains spaces
  CFG = 'yolov8n.yaml'  # model config file name
  SOURCE = ASSETS / 'bus.jpg'  # sample image used as the inference input
  TMP = ROOT.joinpath('..', 'tests', 'tmp').resolve()  # absolute scratch directory for files the tests write
  @pytest.mark.skipif(not check_requirements('ray', install=False), reason='ray[tune] not installed')
  def test_model_ray_tune():
      """Run a minimal Ray-backed tuning job on a model built from 'yolov8n-cls.yaml'.

      The run is kept tiny so it stays cheap: imagenet10 data, one iteration, one epoch,
      32 px images, grace period of 1, plotting disabled, CPU only.
      """
      tune_kwargs = dict(
          data='imagenet10',
          grace_period=1,
          iterations=1,
          imgsz=32,
          epochs=1,
          plots=False,
          device='cpu',
      )
      model = YOLO('yolov8n-cls.yaml')
      model.tune(use_ray=True, **tune_kwargs)
  @pytest.mark.skipif(not check_requirements('mlflow', install=False), reason='mlflow not installed')
  def test_mlflow():
      """Test training with MLflow tracking enabled.

      Temporarily sets ``SETTINGS['mlflow'] = True`` for one short CPU training run
      (imagenet10, 32 px, 3 epochs, no plots). The previous setting is restored afterwards,
      even if training fails, so the flag does not leak into other tests in the same session.
      """
      had_key = 'mlflow' in SETTINGS
      previous = SETTINGS['mlflow'] if had_key else None
      SETTINGS['mlflow'] = True
      try:
          YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=3, plots=False, device='cpu')
      finally:
          # Restore the exact prior state: reinstate the old value, or drop the key if it was absent.
          if had_key:
              SETTINGS['mlflow'] = previous
          else:
              del SETTINGS['mlflow']
  @pytest.mark.skipif(not check_requirements('tritonclient', install=False), reason='tritonclient[all] not installed')
  def test_triton():
      """Test NVIDIA Triton Server functionalities.

      Steps:
          1. Export MODEL to ONNX with dynamic shapes.
          2. Lay the ONNX file out as a Triton model repository under TMP.
          3. Start a Triton server container on port 8000.
          4. Poll up to 10 times for the model to report ready.
          5. Run one inference on SOURCE through the Triton HTTP endpoint.
          6. Always stop the container, even if a step above fails.
      """
      # The skipif above only checks the base package. Without install=False this call presumably
      # installs the [all] extras if they are missing.
      check_requirements('tritonclient[all]')
      import subprocess
      import time

      from tritonclient.http import InferenceServerClient  # noqa

      # Create variables
      model_name = 'yolo'
      triton_repo_path = TMP / 'triton_repo'
      triton_model_path = triton_repo_path / model_name

      # Export model to ONNX
      f = YOLO(MODEL).export(format='onnx', dynamic=True)

      # Prepare Triton repo: <repo>/<model_name>/1/model.onnx plus an empty config.pbtxt.
      # Path.replace() overwrites a model.onnx left over from a previous run on every OS
      # (Path.rename() raises FileExistsError on Windows in that case).
      (triton_model_path / '1').mkdir(parents=True, exist_ok=True)
      Path(f).replace(triton_model_path / '1' / 'model.onnx')
      (triton_model_path / 'config.pbtxt').touch()

      # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
      tag = 'nvcr.io/nvidia/tritonserver:23.09-py3'  # 6.4 GB

      # Pull the image. This is best-effort: the exit status is ignored.
      # Argument lists (no shell) keep paths that contain spaces intact.
      subprocess.call(['docker', 'pull', tag])

      # Run the Triton server detached and capture the container ID
      container_id = subprocess.check_output([
          'docker', 'run', '-d', '--rm', '-v', f'{triton_repo_path}:/models', '-p', '8000:8000', tag,
          'tritonserver', '--model-repository=/models']).decode('utf-8').strip()

      try:
          triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)

          # Poll up to 10 times, 1 s apart. Errors while the server is still booting are ignored.
          # An explicit `if` is used because an `assert` here would be stripped under `python -O`.
          for _ in range(10):
              with contextlib.suppress(Exception):
                  if triton_client.is_model_ready(model_name):
                      break
              time.sleep(1)

          # Check Triton inference
          YOLO(f'http://localhost:8000/{model_name}', 'detect')(SOURCE)  # exported model inference
      finally:
          # Kill the container even when the test fails; --rm makes Docker remove it on exit.
          subprocess.call(['docker', 'kill', container_id])
  @pytest.mark.skipif(not check_requirements('pycocotools', install=False), reason='pycocotools not installed')
  def test_pycocotools():
      """Validate model predictions using pycocotools.

      Runs detection, segmentation and pose in turn. For each task:
          1. Validate on the matching coco8 dataset with ``save_json=True`` at 64 px.
          2. Download the corresponding COCO val2017 annotation file into that dataset's
             ``annotations`` directory.
          3. Score the saved predictions with ``eval_json``.
      """
      from ultralytics.models.yolo.detect import DetectionValidator
      from ultralytics.models.yolo.pose import PoseValidator
      from ultralytics.models.yolo.segment import SegmentationValidator

      url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
      # Columns: validator class, weights, dataset yaml, annotation file, dataset directory under DATASETS_DIR
      cases = (
          (DetectionValidator, 'yolov8n.pt', 'coco8.yaml', 'instances_val2017.json', 'coco8'),
          (SegmentationValidator, 'yolov8n-seg.pt', 'coco8-seg.yaml', 'instances_val2017.json', 'coco8-seg'),
          (PoseValidator, 'yolov8n-pose.pt', 'coco8-pose.yaml', 'person_keypoints_val2017.json', 'coco8-pose'),
      )
      for validator_cls, model, data, annotations, dataset_dir in cases:
          validator = validator_cls(args={'model': model, 'data': data, 'save_json': True, 'imgsz': 64})
          # Validate first so the dataset gets downloaded. Only then fetch its annotations into it.
          validator()
          # NOTE(review): presumably eval_json only runs pycocotools scoring when is_coco is set,
          # and coco8 is not detected as COCO automatically. Confirm against the validator sources.
          validator.is_coco = True
          download(f'{url}{annotations}', dir=DATASETS_DIR / dataset_dir / 'annotations')
          validator.eval_json(validator.stats)  # result unused; the test passes if scoring does not raise
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...