#609 Ci fix

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bugfix/infra-000_ci
import os
import socket
from functools import wraps

from super_gradients.common.environment.device_utils import device_config
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.argparse_utils import pop_local_rank


def init_trainer():
    """
    Initialize the super_gradients environment.
    This function should be the first thing to be called by any code running super_gradients.
    """
    register_hydra_resolvers()
    pop_local_rank()


def is_distributed() -> bool:
    """Check if current process is a DDP subprocess."""
    return device_config.assigned_rank >= 0


def is_launched_using_sg():
    """Check if the current process is a subprocess launched using SG restart_script_with_ddp"""
    return os.environ.get("TORCHELASTIC_RUN_ID") == "sg_initiated"


def is_main_process():
    """Check if current process is considered as the main process (i.e. is responsible for sanity check, atexit upload, ...).
    The definition ensures that 1 and only 1 process follows this condition, regardless of how the run was started.

    The rule is as follows:
        - If not DDP: main process is current process
        - If DDP launched using SuperGradients: main process is the launching process (rank=-1)
        - If DDP launched with torch: main process is rank 0
    """
    if not is_distributed():  # If no DDP, or DDP launching process
        return True
    elif (
        device_config.assigned_rank == 0 and not is_launched_using_sg()
    ):  # If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0
        return True
    else:
        return False


def multi_process_safe(func):
    """
    A decorator for making sure a function runs only in the main process.
    If not in DDP mode (local_rank = -1), the function will run.
    If in DDP mode, the function will run only in the main process (local_rank = 0).
    This works only for functions with no return value.
    """

    def do_nothing(*args, **kwargs):
        pass

    @wraps(func)
    def wrapper(*args, **kwargs):
        if device_config.assigned_rank <= 0:
            return func(*args, **kwargs)
        else:
            return do_nothing(*args, **kwargs)

    return wrapper


def find_free_port() -> int:
    """Find an available port of current machine/node.
    Note: there is still a chance the port could be taken by other processes."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        _ip, port = sock.getsockname()
        return port
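For context, below is a minimal usage sketch of the helpers touched by this diff. The function names (init_trainer, is_main_process, multi_process_safe, find_free_port) come from the file above; the import path is an assumption, since the module's full path is not visible in this view.

# Minimal usage sketch. The import path below is assumed, not confirmed by this PR.
from super_gradients.common.environment.env_helpers import (  # hypothetical module path
    init_trainer,
    is_main_process,
    multi_process_safe,
    find_free_port,
)

# Per the docstring, init_trainer() should be the first call in any script using super_gradients.
init_trainer()


@multi_process_safe
def log_run_start(run_name: str) -> None:
    # Executes only on the main process; becomes a no-op on other DDP workers.
    print(f"Starting run: {run_name}")


log_run_start("example-run")

if is_main_process():
    # e.g. pick a rendezvous port before spawning DDP subprocesses.
    port = find_free_port()
    print(f"Free port: {port}")

Note that the decorator checks assigned_rank <= 0, which covers both the non-DDP / launching process (rank -1) and rank 0 of a torch-launched DDP run, matching the rule spelled out in is_main_process.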