Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#537 Quantization infra mods for different calibrators and learnable amax

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/AL-706-selective-qat
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
  1. import os
  2. import socket
  3. from functools import wraps
  4. from super_gradients.common.environment.argparse_utils import pop_arg
  5. from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
# Local DDP rank of this process, read once at import time from the LOCAL_RANK
# env var (set by torch.distributed.launch / torchrun); -1 means "not running DDP".
# May be overwritten later by init_trainer() when local_rank is passed via CLI args.
DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=-1))
# One-shot guard so init_trainer() performs its setup only once per process.
INIT_TRAINER = False
  8. def init_trainer():
  9. """
  10. Initialize the super_gradients environment.
  11. This function should be the first thing to be called by any code running super_gradients.
  12. It resolves conflicts between the different tools, packages and environments used and prepares the super_gradients environment.
  13. """
  14. global INIT_TRAINER, DDP_LOCAL_RANK
  15. if not INIT_TRAINER:
  16. register_hydra_resolvers()
  17. # We pop local_rank if it was specified in the args, because it would break
  18. args_local_rank = pop_arg("local_rank", default_value=-1)
  19. # Set local_rank with priority order (env variable > args.local_rank > args.default_value)
  20. DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=args_local_rank))
  21. INIT_TRAINER = True
  22. def is_distributed() -> bool:
  23. return DDP_LOCAL_RANK >= 0
  24. def is_rank_0() -> bool:
  25. """Check if the node was launched with torch.distributed.launch and if the node is of rank 0"""
  26. return os.getenv("LOCAL_RANK") == "0"
  27. def is_launched_using_sg():
  28. """Check if the current process is a subprocess launched using SG restart_script_with_ddp"""
  29. return os.environ.get("TORCHELASTIC_RUN_ID") == "sg_initiated"
  30. def is_main_process():
  31. """Check if current process is considered as the main process (i.e. is responsible for sanity check, atexit upload, ...).
  32. The definition ensures that 1 and only 1 process follows this condition, regardless of how the run was started.
  33. The rule is as follow:
  34. - If not DDP: main process is current process
  35. - If DDP launched using SuperGradients: main process is the launching process (rank=-1)
  36. - If DDP launched with torch: main process is rank 0
  37. """
  38. if not is_distributed(): # If no DDP, or DDP launching process
  39. return True
  40. elif is_rank_0() and not is_launched_using_sg(): # If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0
  41. return True
  42. else:
  43. return False
  44. def multi_process_safe(func):
  45. """
  46. A decorator for making sure a function runs only in main process.
  47. If not in DDP mode (local_rank = -1), the function will run.
  48. If in DDP mode, the function will run only in the main process (local_rank = 0)
  49. This works only for functions with no return value
  50. """
  51. def do_nothing(*args, **kwargs):
  52. pass
  53. @wraps(func)
  54. def wrapper(*args, **kwargs):
  55. if DDP_LOCAL_RANK <= 0:
  56. return func(*args, **kwargs)
  57. else:
  58. return do_nothing(*args, **kwargs)
  59. return wrapper
  60. def find_free_port() -> int:
  61. """Find an available port of current machine/node.
  62. Note: there is still a chance the port could be taken by other processes."""
  63. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
  64. # Binding to port 0 will cause the OS to find an available port for us
  65. sock.bind(("", 0))
  66. _ip, port = sock.getsockname()
  67. return port
Discard
Tip!

Press p to see the previous file, or n to see the next file