InferenceSession.py 1.1 KB

import onnx
import onnxruntime as rt
from io import BytesIO

from .device import ORTDeviceInfo


def InferenceSession_with_device(onnx_model_or_path, device_info: ORTDeviceInfo):
    """
    Construct an onnxruntime.InferenceSession bound to the given device.

    device_info : ORTDeviceInfo

    Can raise Exception.
    """
    # An in-memory ModelProto is serialized to bytes, which
    # rt.InferenceSession accepts in place of a file path.
    if isinstance(onnx_model_or_path, onnx.ModelProto):
        b = BytesIO()
        onnx.save(onnx_model_or_path, b)
        onnx_model_or_path = b.getvalue()

    device_ep = device_info.get_execution_provider()
    if device_ep not in rt.get_available_providers():
        raise Exception(f'{device_ep} is not available in onnxruntime')

    ep_flags = {}
    # CUDA and DirectML execution providers need an explicit device index.
    if device_ep in ['CUDAExecutionProvider', 'DmlExecutionProvider']:
        ep_flags['device_id'] = device_info.get_index()

    sess_options = rt.SessionOptions()
    sess_options.log_severity_level = 4
    sess_options.log_verbosity_level = -1
    if device_ep == 'DmlExecutionProvider':
        # The DirectML execution provider does not support the
        # memory pattern optimization, so it must be disabled.
        sess_options.enable_mem_pattern = False

    sess = rt.InferenceSession(onnx_model_or_path,
                               providers=[(device_ep, ep_flags)],
                               sess_options=sess_options)
    return sess
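For reference, the `providers` argument built above is a list of `(provider_name, options_dict)` pairs. A minimal standalone sketch of that construction, using a hypothetical `build_provider_list` helper that is not part of this module:

```python
# Hypothetical helper mirroring how the function above assembles the
# providers argument for rt.InferenceSession: one (name, options) pair.
def build_provider_list(device_ep, device_index):
    ep_flags = {}
    # Only CUDA and DirectML providers take a device index.
    if device_ep in ('CUDAExecutionProvider', 'DmlExecutionProvider'):
        ep_flags['device_id'] = device_index
    return [(device_ep, ep_flags)]


print(build_provider_list('CUDAExecutionProvider', 1))
# → [('CUDAExecutionProvider', {'device_id': 1})]
print(build_provider_list('CPUExecutionProvider', 0))
# → [('CPUExecutionProvider', {})]
```

Passing per-provider options this way (rather than a bare list of provider names) is what lets the session pin CUDA or DirectML to a specific GPU.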