device.py

import ctypes
import itertools
import os
from typing import List

import onnxruntime as rt

from .. import appargs as lib_appargs


class ORTDeviceInfo:
    """
    Represents picklable ONNXRuntime device info
    """
    def __init__(self, index=None, execution_provider=None, name=None, total_memory=None, free_memory=None):
        self._index : int = index
        self._execution_provider : str = execution_provider
        self._name : str = name
        self._total_memory : int = total_memory
        self._free_memory : int = free_memory

    def __getstate__(self):
        return self.__dict__.copy()

    def __setstate__(self, d):
        self.__init__()
        self.__dict__.update(d)

    def is_cpu(self) -> bool: return self._index == -1

    def get_index(self) -> int:
        return self._index

    def get_execution_provider(self) -> str:
        return self._execution_provider

    def get_name(self) -> str:
        return self._name

    def get_total_memory(self) -> int:
        return self._total_memory

    def get_free_memory(self) -> int:
        return self._free_memory

    def __eq__(self, other):
        if self is not None and other is not None and isinstance(self, ORTDeviceInfo) and isinstance(other, ORTDeviceInfo):
            return self._index == other._index
        return False

    def __hash__(self):
        return self._index

    def __str__(self):
        if self.is_cpu():
            return "CPU"
        else:
            ep = self.get_execution_provider()
            if ep == 'CUDAExecutionProvider':
                return f"[{self._index}] {self._name} [{(self._total_memory / 1024**3) :.3}Gb] [CUDA]"
            elif ep == 'DmlExecutionProvider':
                return f"[{self._index}] {self._name} [{(self._total_memory / 1024**3) :.3}Gb] [DirectX12]"

    def __repr__(self):
        return f'{self.__class__.__name__} object: ' + self.__str__()
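
# Note (added): ORTDeviceInfo defines __getstate__/__setstate__, so instances
# survive pickling across 'spawn' process boundaries. An illustrative check
# (the constructor values below are made up):
#
#   import pickle
#   info = ORTDeviceInfo(index=0, execution_provider='CUDAExecutionProvider',
#                        name='GeForce', total_memory=8 * 1024**3, free_memory=6 * 1024**3)
#   assert pickle.loads(pickle.dumps(info)) == info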

_ort_devices_info = None

def get_cpu_device_info() -> ORTDeviceInfo:
    return ORTDeviceInfo(index=-1, execution_provider='CPUExecutionProvider', name='CPU', total_memory=0, free_memory=0)

def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[ORTDeviceInfo]:
    """
    returns a list of available ORTDeviceInfo
    """
    devices = []

    if not cpu_only:
        global _ort_devices_info
        if _ort_devices_info is None:
            # Rebuild the cached device list from the os.environ entries
            # written by _initialize_ort_devices_info()
            _initialize_ort_devices_info()
            _ort_devices_info = []
            for i in range( int(os.environ.get('ORT_DEVICES_COUNT', 0)) ):
                _ort_devices_info.append( ORTDeviceInfo(index=int(os.environ[f'ORT_DEVICE_{i}_INDEX']),
                                                        execution_provider=os.environ[f'ORT_DEVICE_{i}_EP'],
                                                        name=os.environ[f'ORT_DEVICE_{i}_NAME'],
                                                        total_memory=int(os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM']),
                                                        free_memory=int(os.environ[f'ORT_DEVICE_{i}_FREE_MEM']),
                                                        ) )
        devices += _ort_devices_info

    if include_cpu:
        devices.append(get_cpu_device_info())

    return devices

def _initialize_ort_devices_info():
    """
    Determine available ORT devices and place info about them into os.environ,
    so that they will be available in spawned subprocesses.

    Uses only python ctypes and the default lib provided with NVIDIA drivers.
    """
    if int(os.environ.get('ORT_DEVICES_INITIALIZED', 0)) == 0:
        os.environ['ORT_DEVICES_INITIALIZED'] = '1'
        os.environ['ORT_DEVICES_COUNT'] = '0'

        devices = []
        prs = rt.get_available_providers()

        if not lib_appargs.get_arg_bool('NO_CUDA') and 'CUDAExecutionProvider' in prs:
            os.environ['CUDA_CACHE_MAXSIZE'] = '2147483647'
            try:
                # Load the CUDA driver library shipped with the NVIDIA driver
                libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
                for libname in libnames:
                    try:
                        cuda = ctypes.CDLL(libname)
                    except:
                        continue
                    else:
                        break
                else:
                    return

                nGpus = ctypes.c_int()
                name = b' ' * 200
                cc_major = ctypes.c_int()
                cc_minor = ctypes.c_int()
                freeMem = ctypes.c_size_t()
                totalMem = ctypes.c_size_t()
                device = ctypes.c_int()
                context = ctypes.c_void_p()

                if cuda.cuInit(0) == 0 and \
                   cuda.cuDeviceGetCount(ctypes.byref(nGpus)) == 0:
                    for i in range(nGpus.value):
                        if cuda.cuDeviceGet(ctypes.byref(device), i) != 0 or \
                           cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device) != 0 or \
                           cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device) != 0:
                            continue

                        if cuda.cuCtxCreate_v2(ctypes.byref(context), 0, device) == 0:
                            if cuda.cuMemGetInfo_v2(ctypes.byref(freeMem), ctypes.byref(totalMem)) == 0:
                                cc = cc_major.value * 10 + cc_minor.value
                                devices.append({'index'              : i,
                                                'execution_provider' : 'CUDAExecutionProvider',
                                                'name'               : name.split(b'\0', 1)[0].decode(),
                                                'total_mem'          : totalMem.value,
                                                'free_mem'           : freeMem.value,
                                                })
                            cuda.cuCtxDetach(context)
            except Exception as e:
                print(f'CUDA devices initialization error: {e}')

        if 'DmlExecutionProvider' in prs:
            # onnxruntime-directml has no device enumeration API for users. Thus the code must follow the same logic
            # as here https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/dml/dml_provider_factory.cc
            from xlib.api.win32 import dxgi as lib_dxgi

            dxgi_factory = lib_dxgi.create_DXGIFactory4()
            if dxgi_factory is not None:
                for i in itertools.count():
                    adapter = dxgi_factory.enum_adapters1(i)
                    if adapter is not None:
                        desc = adapter.get_desc1()
                        # Skip software adapters and the Microsoft Basic Render Driver
                        if desc.Flags != lib_dxgi.DXGI_ADAPTER_FLAG.DXGI_ADAPTER_FLAG_SOFTWARE and \
                           not (desc.VendorId == 0x1414 and desc.DeviceId == 0x8c):
                            devices.append({'index'              : i,
                                            'execution_provider' : 'DmlExecutionProvider',
                                            'name'               : desc.Description,
                                            'total_mem'          : desc.DedicatedVideoMemory,
                                            'free_mem'           : desc.DedicatedVideoMemory,
                                            })
                        adapter.Release()
                    else:
                        break
                dxgi_factory.Release()

        os.environ['ORT_DEVICES_COUNT'] = str(len(devices))
        for i, device in enumerate(devices):
            os.environ[f'ORT_DEVICE_{i}_INDEX'] = str(device['index'])
            os.environ[f'ORT_DEVICE_{i}_EP'] = device['execution_provider']
            os.environ[f'ORT_DEVICE_{i}_NAME'] = device['name']
            os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
            os.environ[f'ORT_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])

_initialize_ort_devices_info()
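
# --- Hedged usage sketch (added; not part of the original module) ---
# A minimal example of consuming this API: enumerate devices and feed the
# chosen execution provider to an onnxruntime InferenceSession. The model
# path 'model.onnx' is an illustrative assumption, and the relative appargs
# import above means this file must be imported as part of its package.
if __name__ == '__main__':
    available = get_available_devices_info()   # GPU devices first, CPU appended last
    dev = available[0]
    print('Selected device:', dev)

    if dev.is_cpu():
        providers = ['CPUExecutionProvider']
    elif dev.get_execution_provider() == 'CUDAExecutionProvider':
        # (provider_name, options) tuples are accepted by InferenceSession
        providers = [('CUDAExecutionProvider', {'device_id': dev.get_index()})]
    else:
        providers = [dev.get_execution_provider()]

    sess = rt.InferenceSession('model.onnx', providers=providers)
    print('Session providers:', sess.get_providers())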