Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commit into Ultralytics:main from ultralytics:yoloe-vp-fix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import ast
from typing import List
from urllib.parse import urlsplit

import numpy as np
  5. class TritonRemoteModel:
  6. """
  7. Client for interacting with a remote Triton Inference Server model.
  8. This class provides a convenient interface for sending inference requests to a Triton Inference Server
  9. and processing the responses.
  10. Attributes:
  11. endpoint (str): The name of the model on the Triton server.
  12. url (str): The URL of the Triton server.
  13. triton_client: The Triton client (either HTTP or gRPC).
  14. InferInput: The input class for the Triton client.
  15. InferRequestedOutput: The output request class for the Triton client.
  16. input_formats (List[str]): The data types of the model inputs.
  17. np_input_formats (List[type]): The numpy data types of the model inputs.
  18. input_names (List[str]): The names of the model inputs.
  19. output_names (List[str]): The names of the model outputs.
  20. metadata: The metadata associated with the model.
  21. Methods:
  22. __call__: Call the model with the given inputs and return the outputs.
  23. Examples:
  24. Initialize a Triton client with HTTP
  25. >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
  26. Make inference with numpy arrays
  27. >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
  28. """
  29. def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
  30. """
  31. Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.
  32. Arguments may be provided individually or parsed from a collective 'url' argument of the form
  33. <scheme>://<netloc>/<endpoint>/<task_name>
  34. Args:
  35. url (str): The URL of the Triton server.
  36. endpoint (str): The name of the model on the Triton server.
  37. scheme (str): The communication scheme ('http' or 'grpc').
  38. Examples:
  39. >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
  40. >>> model = TritonRemoteModel(url="http://localhost:8000/yolov8")
  41. """
  42. if not endpoint and not scheme: # Parse all args from URL string
  43. splits = urlsplit(url)
  44. endpoint = splits.path.strip("/").split("/")[0]
  45. scheme = splits.scheme
  46. url = splits.netloc
  47. self.endpoint = endpoint
  48. self.url = url
  49. # Choose the Triton client based on the communication scheme
  50. if scheme == "http":
  51. import tritonclient.http as client # noqa
  52. self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
  53. config = self.triton_client.get_model_config(endpoint)
  54. else:
  55. import tritonclient.grpc as client # noqa
  56. self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
  57. config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]
  58. # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
  59. config["output"] = sorted(config["output"], key=lambda x: x.get("name"))
  60. # Define model attributes
  61. type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
  62. self.InferRequestedOutput = client.InferRequestedOutput
  63. self.InferInput = client.InferInput
  64. self.input_formats = [x["data_type"] for x in config["input"]]
  65. self.np_input_formats = [type_map[x] for x in self.input_formats]
  66. self.input_names = [x["name"] for x in config["input"]]
  67. self.output_names = [x["name"] for x in config["output"]]
  68. self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))
  69. def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
  70. """
  71. Call the model with the given inputs.
  72. Args:
  73. *inputs (np.ndarray): Input data to the model. Each array should match the expected shape and type
  74. for the corresponding model input.
  75. Returns:
  76. (List[np.ndarray]): Model outputs with the same dtype as the input. Each element in the list
  77. corresponds to one of the model's output tensors.
  78. Examples:
  79. >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
  80. >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
  81. """
  82. infer_inputs = []
  83. input_format = inputs[0].dtype
  84. for i, x in enumerate(inputs):
  85. if x.dtype != self.np_input_formats[i]:
  86. x = x.astype(self.np_input_formats[i])
  87. infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
  88. infer_input.set_data_from_numpy(x)
  89. infer_inputs.append(infer_input)
  90. infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
  91. outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)
  92. return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
Discard
Tip!

Press p to see the previous file, or n to see the next file.