import argparse
import importlib
import os
import socket
import sys
from functools import wraps
from typing import Any

from omegaconf import OmegaConf

from super_gradients.common.environment import environment_config


class TerminalColours:
    """
    ANSI escape codes for coloured terminal output.
    Reference: https://stackoverflow.com/questions/287871/how-to-print-colored-text-in-python?page=1&tab=votes#tab-top
    """

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


class ColouredTextFormatter:
    @staticmethod
    def print_coloured_text(text: str, colour: str):
        """
        Print text wrapped in ANSI colour escape codes.
        """
        return print("".join([colour, text, TerminalColours.ENDC]))


def get_cls(cls_path):
    """
    A resolver for Hydra/OmegaConf that returns a class instead of an instance.
    Usage:
        class_of_optimizer: ${class:torch.optim.Adam}
    """
    module = ".".join(cls_path.split(".")[:-1])
    name = cls_path.split(".")[-1]
    importlib.import_module(module)
    return getattr(sys.modules[module], name)
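

# A quick illustrative check of what the resolver returns (assumes torch is installed;
# shown as a comment since it is not part of this module):
#
# import torch
# assert get_cls("torch.optim.Adam") is torch.optim.Adam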


def get_environ_as_type(environment_variable_name: str, default=None, cast_to_type: type = str) -> object:
    """
    Try to get an environment variable and cast it into the requested type.

    :param environment_variable_name: Name of the environment variable to read.
    :param default:                   Value to fall back on when the variable is not set.
    :param cast_to_type:              Type to cast the value into.
    :return: The value cast into `cast_to_type`, or None if the variable is not set and no default was given.
    :raises ValueError: If the value could not be cast into the type `cast_to_type`.
    """
    value = os.environ.get(environment_variable_name, default)
    if value is None:
        return None
    try:
        return cast_to_type(value)
    except Exception as e:
        raise ValueError(
            f"Failed to cast environment variable {environment_variable_name} to type {cast_to_type}: the value {value} is not a valid {cast_to_type}"
        ) from e
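

# A small usage sketch (the variable name below is made up for the example):
#
# os.environ["SG_NUM_WORKERS"] = "8"
# num_workers = get_environ_as_type("SG_NUM_WORKERS", default=4, cast_to_type=int)  # -> 8
# missing = get_environ_as_type("SG_UNSET_VAR")                                     # -> None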


def hydra_output_dir_resolver(ckpt_root_dir, experiment_name):
    """Resolve the run output directory: <ckpt_root_dir>/<experiment_name>, falling back to the package checkpoints dir."""
    root_dir = environment_config.PKG_CHECKPOINTS_DIR if ckpt_root_dir is None else ckpt_root_dir
    return os.path.join(root_dir, experiment_name)


def init_trainer():
    """
    Initialize the super_gradients environment.
    This function should be the first thing to be called by any code running super_gradients.
    It resolves conflicts between the different tools, packages and environments used, and prepares the super_gradients environment.
    """
    if not environment_config.INIT_TRAINER:
        register_hydra_resolvers()

        # Pop local_rank if it was specified in the args, because it would conflict with hydra's own argument parsing
        args_local_rank = pop_arg("local_rank", default_value=-1)

        # Set local_rank with priority order (env variable > args.local_rank > default_value)
        environment_config.DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=args_local_rank))
        environment_config.INIT_TRAINER = True


def register_hydra_resolvers():
    """Register all the hydra resolvers required for the super-gradients recipes."""
    OmegaConf.register_new_resolver("hydra_output_dir", hydra_output_dir_resolver, replace=True)
    OmegaConf.register_new_resolver("class", lambda *args: get_cls(*args), replace=True)
    OmegaConf.register_new_resolver("add", lambda *args: sum(args), replace=True)
    OmegaConf.register_new_resolver("cond", lambda boolean, x, y: x if boolean else y, replace=True)
    OmegaConf.register_new_resolver("getitem", lambda container, key: container[key], replace=True)  # get item from a container (list, dict...)
    OmegaConf.register_new_resolver("first", lambda lst: lst[0], replace=True)  # get the first item from a list
    OmegaConf.register_new_resolver("last", lambda lst: lst[-1], replace=True)  # get the last item from a list


def pop_arg(arg_name: str, default_value: Any = None) -> Any:
    """Parse the specified arg from sys.argv, remove it from sys.argv, and return its value."""
    parser = argparse.ArgumentParser()
    parser.add_argument(f"--{arg_name}", default=default_value)
    args, _ = parser.parse_known_args()

    # Remove the ddp args so they do not conflict with the use of hydra
    # (materialize the filter into a list first, since we mutate sys.argv while iterating)
    for val in list(filter(lambda x: x.startswith(f"--{arg_name}"), sys.argv)):
        environment_config.EXTRA_ARGS.append(val)
        sys.argv.remove(val)
    return vars(args)[arg_name]


def is_distributed() -> bool:
    """Check if the current process is part of a DDP run (i.e. its local rank was set)."""
    return environment_config.DDP_LOCAL_RANK >= 0


def is_rank_0() -> bool:
    """Check if the node was launched with torch.distributed.launch and if the node is of rank 0."""
    return os.getenv("LOCAL_RANK") == "0"


def is_launched_using_sg() -> bool:
    """Check if the current process is a subprocess launched using SG restart_script_with_ddp."""
    return os.environ.get("TORCHELASTIC_RUN_ID") == "sg_initiated"


def is_main_process() -> bool:
    """Check if the current process is considered the main process (i.e. is responsible for sanity checks, atexit upload, ...).
    The definition ensures that exactly one process satisfies this condition, regardless of how the run was started.
    The rule is as follows:
        - If not DDP: the main process is the current process
        - If DDP launched using SuperGradients: the main process is the launching process (local_rank=-1)
        - If DDP launched with torch: the main process is rank 0
    """
    if not is_distributed():  # If no DDP, or DDP launching process
        return True
    elif is_rank_0() and not is_launched_using_sg():  # If DDP launched using torch.distributed.launch or torchrun, run the check on rank 0
        return True
    else:
        return False


def multi_process_safe(func):
    """
    A decorator that makes sure a function runs only in the main process.
    If not in DDP mode (local_rank = -1), the function runs.
    If in DDP mode, the function runs only in the main process (local_rank = 0).
    Note: on non-main processes the wrapper returns None, so this is only suitable for functions whose return value is not used.
    """

    def do_nothing(*args, **kwargs):
        pass

    @wraps(func)
    def wrapper(*args, **kwargs):
        if environment_config.DDP_LOCAL_RANK <= 0:
            return func(*args, **kwargs)
        else:
            return do_nothing(*args, **kwargs)

    return wrapper
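

# A minimal usage sketch (the decorated function below is hypothetical, not part of this
# module): a logging helper that, under DDP, prints only on rank 0 or the launching process.
#
# @multi_process_safe
# def log_experiment_start(experiment_name: str) -> None:
#     print(f"Starting experiment: {experiment_name}")
#
# log_experiment_start("my_experiment")  # printed once, even with 8 DDP workers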


def find_free_port() -> int:
    """Find an available port on the current machine/node.
    Note: there is still a chance the port could be taken by another process before it is actually used."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 causes the OS to pick an available port for us
        sock.bind(("", 0))
        _ip, port = sock.getsockname()
        return port
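

# A common usage sketch: picking a rendezvous port before spawning DDP workers
# (MASTER_PORT is the standard torch.distributed environment variable):
#
# os.environ["MASTER_PORT"] = str(find_free_port())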