Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

extension.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
  1. import os
  2. import sys
  3. import torch
  4. from ._internally_replaced_utils import _get_extension_path
  _HAS_OPS = False


  def _has_ops():
      """Return whether torchvision's compiled C++ ops were loaded.

      This fallback always returns ``False``; it is replaced by a version that
      returns ``True`` once the ``_C`` extension library loads successfully.
      """
      return False


  try:
      # Since Python 3.8, Windows no longer resolves the DLL dependencies of
      # extension modules through PATH; directories must be registered with
      # ``os.add_dll_directory``. To find CUDA-related DLLs (e.g. in a conda
      # environment's bin directory) every existing PATH entry is registered.
      # This is only done for Python < 3.9. Registration is best effort: a
      # directory that cannot be added is simply skipped.
      # See https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
      if os.name == "nt" and sys.version_info < (3, 9):
          # Use .get(): an unset PATH would otherwise raise KeyError, which
          # the ``except (ImportError, OSError)`` below does not catch and
          # would therefore abort the whole import.
          for path in os.environ.get("PATH", "").split(os.pathsep):
              if os.path.exists(path):
                  try:
                      os.add_dll_directory(path)  # type: ignore[attr-defined]
                  except Exception:
                      # Deliberate best effort; see the note above.
                      pass

      lib_path = _get_extension_path("_C")
      torch.ops.load_library(lib_path)
      _HAS_OPS = True

      def _has_ops():  # noqa: F811
          """Return ``True``: the ``_C`` extension was loaded successfully."""
          return True

  except (ImportError, OSError):
      # Loading failed: keep _HAS_OPS False and the fallback _has_ops() so
      # the ops are reported as unavailable instead of failing the import.
      pass
  def _assert_has_ops():
      """Raise ``RuntimeError`` unless torchvision's custom C++ ops are available.

      Availability is decided by ``_has_ops()``. On failure, the error message
      names likely causes (incompatible PyTorch/torchvision versions or a
      broken source build) and points at the compatibility matrix.
      """
      if _has_ops():
          return

      message = (
          "Couldn't load custom C++ ops. This can happen if your PyTorch and "
          "torchvision versions are incompatible, or if you had errors while compiling "
          "torchvision from source. For further information on the compatible versions, check "
          "https://github.com/pytorch/vision#installation for the compatibility matrix. "
          "Please check your PyTorch version with torch.__version__ and your torchvision "
          "version with torchvision.__version__ and verify if they are compatible, and if not "
          "please reinstall torchvision so that it matches your PyTorch install."
      )
      raise RuntimeError(message)
  def _check_cuda_version():
      """
      Make sure that CUDA versions match between the pytorch install and torchvision install.

      Returns:
          The value reported by ``torch.ops.torchvision._cuda_version()``
          (``-1`` when torchvision reports no CUDA build), or ``-1`` when the
          C++ ops were not loaded at all (``_HAS_OPS`` is False).

      Raises:
          RuntimeError: if both PyTorch and torchvision report a CUDA version
              and their major versions differ.
      """
      if not _HAS_OPS:
          return -1
      from torch.version import cuda as torch_version_cuda

      _version = torch.ops.torchvision._cuda_version()
      if _version != -1 and torch_version_cuda is not None:
          # torchvision reports the CUDA version as a single integer encoded
          # as major * 1000 + minor * 10 (e.g. 11070 -> 11.7). Decode it
          # arithmetically rather than by slicing decimal digits, which
          # misreads two-digit minors and fails on short codes.
          tv_major, tv_remainder = divmod(int(_version), 1000)
          tv_minor = tv_remainder // 10

          # torch.version.cuda is a "major.minor" string such as "11.7".
          t_version = torch_version_cuda.split(".")
          t_major = int(t_version[0])
          t_minor = int(t_version[1])

          # Only the major version has to agree; the minor is informational.
          if t_major != tv_major:
              raise RuntimeError(
                  "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
                  f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
                  f"CUDA Version={tv_major}.{tv_minor}. "
                  "Please reinstall the torchvision that matches your PyTorch install."
              )
      return _version
  def _load_library(lib_name):
      """Load the compiled extension library ``lib_name`` into ``torch.ops``.

      Args:
          lib_name: extension name; its file path is resolved with
              ``_get_extension_path`` before being handed to
              ``torch.ops.load_library``.
      """
      torch.ops.load_library(_get_extension_path(lib_name))


  # Import-time sanity check: raises RuntimeError on a CUDA major-version
  # mismatch between PyTorch and torchvision, and is a no-op (returning -1)
  # when the C++ ops failed to load.
  _check_cuda_version()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...