#548 Split and rename the modules from super_gradients.common.environment

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000-refactor_environment_package
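
For downstream users, the practical effect of this PR is an import-path change. A minimal before/after sketch based on the diffs below (the exact symbol list is illustrative):

# Hypothetical example, not part of the PR.
# Before this PR (old module):
# from super_gradients.common.environment.env_helpers import init_trainer, is_distributed, multi_process_safe

# After this PR (new module):
from super_gradients.common.environment.ddp_utils import init_trainer, is_distributed, multi_process_safe

init_trainer()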
@@ -4,7 +4,7 @@ from super_gradients.common.decorators import explicit_params_validation, single
 from super_gradients.common.aws_connection import AWSConnector
 from super_gradients.common.data_connection import S3Connector
 from super_gradients.common.data_interface import DatasetDataInterface, ADNNModelRepositoryDataInterfaces
-from super_gradients.common.environment.env_helpers import init_trainer, is_distributed
+from super_gradients.common.environment.ddp_utils import init_trainer, is_distributed
 from super_gradients.common.data_types import StrictLoad, DeepLearningTask, EvaluationType, MultiGPUMode, UpsampleMode
 from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig

@@ -6,7 +6,7 @@ from io import StringIO
 import atexit
 from threading import Lock

-from super_gradients.common.environment.env_helpers import multi_process_safe, is_main_process
+from super_gradients.common.environment.ddp_utils import multi_process_safe, is_main_process


 class BufferWriter:
@@ -3,7 +3,7 @@ import logging
 import atexit

 from super_gradients.common.auto_logging.console_logging import ConsoleSink
-from super_gradients.common.environment.env_helpers import multi_process_safe, is_distributed
+from super_gradients.common.environment.ddp_utils import multi_process_safe, is_distributed
 from super_gradients.common.crash_handler.exception import ExceptionInfo

 try:
@@ -1,7 +1,6 @@
 """
 This module is in charge of environment variables and consts.
 """
-from super_gradients.common.environment.environment_config import DDP_LOCAL_RANK
-from super_gradients.common.environment.env_helpers import init_trainer, is_distributed
+from super_gradients.common.environment.ddp_utils import init_trainer, is_distributed

-__all__ = ['DDP_LOCAL_RANK', 'init_trainer', 'is_distributed']
+__all__ = ["init_trainer", "is_distributed"]
import argparse
import sys
from typing import Any

EXTRA_ARGS = []


def pop_arg(arg_name: str, default_value: Any = None) -> Any:
    """Get the specified args and remove them from argv"""
    parser = argparse.ArgumentParser()
    parser.add_argument(f"--{arg_name}", default=default_value)
    args, _ = parser.parse_known_args()

    # Remove the ddp args to not have a conflict with the use of hydra
    for val in filter(lambda x: x.startswith(f"--{arg_name}"), sys.argv):
        EXTRA_ARGS.append(val)
        sys.argv.remove(val)
    return vars(args)[arg_name]
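
A quick usage sketch of pop_arg (hypothetical script and flag value; the diffs below import it from super_gradients.common.environment.argparse_utils):

# Hypothetical example, not part of the PR.
import sys

from super_gradients.common.environment.argparse_utils import EXTRA_ARGS, pop_arg

# Suppose the process was started as: python train.py --local_rank=2 experiment=my_exp
local_rank = pop_arg("local_rank", default_value=-1)

print(local_rank)  # "2" if the flag was passed, otherwise the default -1
print(EXTRA_ARGS)  # ["--local_rank=2"], kept so a DDP restart can forward the flag
print(sys.argv)    # the flag has been removed, so Hydra will not trip over it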
import os
import socket
from functools import wraps

from super_gradients.common.environment.argparse_utils import pop_arg
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers

DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=-1))
INIT_TRAINER = False


def init_trainer():
    """
    Initialize the super_gradients environment.
    This function should be the first thing to be called by any code running super_gradients.
    It resolves conflicts between the different tools, packages and environments used and prepares the super_gradients environment.
    """
    global INIT_TRAINER, DDP_LOCAL_RANK

    if not INIT_TRAINER:
        register_hydra_resolvers()

        # We pop local_rank if it was specified in the args, because it would break
        args_local_rank = pop_arg("local_rank", default_value=-1)

        # Set local_rank with priority order (env variable > args.local_rank > args.default_value)
        DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=args_local_rank))
        INIT_TRAINER = True


def is_distributed() -> bool:
    return DDP_LOCAL_RANK >= 0


def is_rank_0() -> bool:
    """Check if the node was launched with torch.distributed.launch and if the node is of rank 0"""
    return os.getenv("LOCAL_RANK") == "0"


def is_launched_using_sg():
    """Check if the current process is a subprocess launched using SG restart_script_with_ddp"""
    return os.environ.get("TORCHELASTIC_RUN_ID") == "sg_initiated"


def is_main_process():
    """Check if current process is considered as the main process (i.e. is responsible for sanity check, atexit upload, ...).
    The definition ensures that 1 and only 1 process follows this condition, regardless of how the run was started.

    The rule is as follow:
        - If not DDP: main process is current process
        - If DDP launched using SuperGradients: main process is the launching process (rank=-1)
        - If DDP launched with torch: main process is rank 0
    """
    if not is_distributed():  # If no DDP, or DDP launching process
        return True
    elif is_rank_0() and not is_launched_using_sg():  # If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0
        return True
    else:
        return False


def multi_process_safe(func):
    """
    A decorator for making sure a function runs only in main process.
    If not in DDP mode (local_rank = -1), the function will run.
    If in DDP mode, the function will run only in the main process (local_rank = 0)
    This works only for functions with no return value
    """

    def do_nothing(*args, **kwargs):
        pass

    @wraps(func)
    def wrapper(*args, **kwargs):
        if DDP_LOCAL_RANK <= 0:
            return func(*args, **kwargs)
        else:
            return do_nothing(*args, **kwargs)

    return wrapper


def find_free_port() -> int:
    """Find an available port of current machine/node.
    Note: there is still a chance the port could be taken by other processes."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        _ip, port = sock.getsockname()
        return port
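
A minimal sketch of how the relocated helpers are meant to be used (hypothetical function and prints; behaviour depends on whether the launcher sets LOCAL_RANK):

# Hypothetical example, not part of the PR.
from super_gradients.common.environment.ddp_utils import init_trainer, is_distributed, is_main_process, multi_process_safe

init_trainer()  # registers the Hydra resolvers and resolves local_rank before anything else


@multi_process_safe
def log_banner():
    # In DDP mode this runs only when local_rank <= 0; outside DDP it always runs.
    print("starting experiment")


log_banner()

if is_main_process():
    print("sanity checks / atexit uploads happen here exactly once per run")

print("distributed:", is_distributed())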
import argparse
import importlib
import os
import socket
import sys
from functools import wraps
from typing import Any

from omegaconf import OmegaConf

from super_gradients.common.environment import environment_config


class TerminalColours:
    """
    Usage: https://stackoverflow.com/questions/287871/how-to-print-colored-text-in-python?page=1&tab=votes#tab-top
    """

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


class ColouredTextFormatter:
    @staticmethod
    def print_coloured_text(text: str, colour: str):
        """
        Prints a text with colour ascii characters.
        """
        return print("".join([colour, text, TerminalColours.ENDC]))


def get_cls(cls_path):
    """
    A resolver for Hydra/OmegaConf to allow getting a class instead on an instance.
    usage:
    class_of_optimizer: ${class:torch.optim.Adam}
    """
    module = ".".join(cls_path.split(".")[:-1])
    name = cls_path.split(".")[-1]
    importlib.import_module(module)
    return getattr(sys.modules[module], name)


def get_environ_as_type(environment_variable_name: str, default=None, cast_to_type: type = str) -> object:
    """
    Tries to get an environment variable and cast it into a requested type.
    :return: cast_to_type object, or None if failed.
    :raises ValueError: If the value could not be casted into type 'cast_to_type'
    """
    value = os.environ.get(environment_variable_name, default)
    if value is not None:
        try:
            return cast_to_type(value)
        except Exception as e:
            print(e)
            raise ValueError(
                f"Failed to cast environment variable {environment_variable_name} to type {cast_to_type}: the value {value} is not a valid {cast_to_type}"
            )
    return


def hydra_output_dir_resolver(ckpt_root_dir, experiment_name):
    if ckpt_root_dir is None:
        output_dir_path = environment_config.PKG_CHECKPOINTS_DIR + os.path.sep + experiment_name
    else:
        output_dir_path = ckpt_root_dir + os.path.sep + experiment_name
    return output_dir_path


def init_trainer():
    """
    Initialize the super_gradients environment.
    This function should be the first thing to be called by any code running super_gradients.
    It resolves conflicts between the different tools, packages and environments used and prepares the super_gradients environment.
    """
    if not environment_config.INIT_TRAINER:
        register_hydra_resolvers()

        # We pop local_rank if it was specified in the args, because it would break
        args_local_rank = pop_arg("local_rank", default_value=-1)

        # Set local_rank with priority order (env variable > args.local_rank > args.default_value)
        environment_config.DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=args_local_rank))
        environment_config.INIT_TRAINER = True


def register_hydra_resolvers():
    """Register all the hydra resolvers required for the super-gradients recipes."""
    OmegaConf.register_new_resolver("hydra_output_dir", hydra_output_dir_resolver, replace=True)
    OmegaConf.register_new_resolver("class", lambda *args: get_cls(*args), replace=True)
    OmegaConf.register_new_resolver("add", lambda *args: sum(args), replace=True)
    OmegaConf.register_new_resolver("cond", lambda boolean, x, y: x if boolean else y, replace=True)
    OmegaConf.register_new_resolver("getitem", lambda container, key: container[key], replace=True)  # get item from a container (list, dict...)
    OmegaConf.register_new_resolver("first", lambda lst: lst[0], replace=True)  # get the first item from a list
    OmegaConf.register_new_resolver("last", lambda lst: lst[-1], replace=True)  # get the last item from a list


def pop_arg(arg_name: str, default_value: Any = None) -> Any:
    """Get the specified args and remove them from argv"""
    parser = argparse.ArgumentParser()
    parser.add_argument(f"--{arg_name}", default=default_value)
    args, _ = parser.parse_known_args()

    # Remove the ddp args to not have a conflict with the use of hydra
    for val in filter(lambda x: x.startswith(f"--{arg_name}"), sys.argv):
        environment_config.EXTRA_ARGS.append(val)
        sys.argv.remove(val)
    return vars(args)[arg_name]


def is_distributed() -> bool:
    return environment_config.DDP_LOCAL_RANK >= 0


def is_rank_0() -> bool:
    """Check if the node was launched with torch.distributed.launch and if the node is of rank 0"""
    return os.getenv("LOCAL_RANK") == "0"


def is_launched_using_sg():
    """Check if the current process is a subprocess launched using SG restart_script_with_ddp"""
    return os.environ.get("TORCHELASTIC_RUN_ID") == "sg_initiated"


def is_main_process():
    """Check if current process is considered as the main process (i.e. is responsible for sanity check, atexit upload, ...).
    The definition ensures that 1 and only 1 process follows this condition, regardless of how the run was started.

    The rule is as follow:
        - If not DDP: main process is current process
        - If DDP launched using SuperGradients: main process is the launching process (rank=-1)
        - If DDP launched with torch: main process is rank 0
    """
    if not is_distributed():  # If no DDP, or DDP launching process
        return True
    elif is_rank_0() and not is_launched_using_sg():  # If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0
        return True
    else:
        return False


def multi_process_safe(func):
    """
    A decorator for making sure a function runs only in main process.
    If not in DDP mode (local_rank = -1), the function will run.
    If in DDP mode, the function will run only in the main process (local_rank = 0)
    This works only for functions with no return value
    """

    def do_nothing(*args, **kwargs):
        pass

    @wraps(func)
    def wrapper(*args, **kwargs):
        if environment_config.DDP_LOCAL_RANK <= 0:
            return func(*args, **kwargs)
        else:
            return do_nothing(*args, **kwargs)

    return wrapper


def find_free_port() -> int:
    """Find an available port of current machine/node.
    Note: there is still a chance the port could be taken by other processes."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        _ip, port = sock.getsockname()
        return port
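
A small, hypothetical sketch of the get_environ_as_type helper from the combined module above (the environment-variable name is made up, and the import path assumes this module keeps exposing the function):

# Hypothetical example, not part of the PR.
import os

from super_gradients.common.environment.env_helpers import get_environ_as_type  # path assumed from the file above

os.environ["SG_DEBUG_PORT"] = "8080"  # made-up variable, for illustration only

port = get_environ_as_type("SG_DEBUG_PORT", default=None, cast_to_type=int)
print(port)  # 8080 as an int; None if the variable is unset and no default is given

# A value that cannot be cast raises ValueError:
# os.environ["SG_DEBUG_PORT"] = "not-a-port"
# get_environ_as_type("SG_DEBUG_PORT", cast_to_type=int)  # -> ValueError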
@@ -7,9 +7,3 @@ try:
 except Exception:
     os.makedirs(os.path.join(os.getcwd(), "checkpoints"), exist_ok=True)
     PKG_CHECKPOINTS_DIR = os.path.join(os.getcwd(), "checkpoints")
-
-
-DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=-1))
-EXTRA_ARGS = []
-
-INIT_TRAINER = False
@@ -1,7 +1,7 @@
 import time
 import threading

-from super_gradients.common.environment.env_helpers import multi_process_safe
+from super_gradients.common.environment.ddp_utils import multi_process_safe
 from super_gradients.common.environment.monitoring import disk, virtual_memory, network, cpu, gpu
 from super_gradients.common.environment.monitoring.utils import average, delta_per_s
 from super_gradients.common.environment.monitoring.data_models import StatAggregator, GPUStatAggregatorIterator
import importlib
import os
import sys

from omegaconf import OmegaConf

from super_gradients.common.environment import environment_config


def get_cls(cls_path: str):
    """
    A resolver for Hydra/OmegaConf to allow getting a class instead on an instance.
    usage:
    class_of_optimizer: ${class:torch.optim.Adam}
    """
    module = ".".join(cls_path.split(".")[:-1])
    name = cls_path.split(".")[-1]
    importlib.import_module(module)
    return getattr(sys.modules[module], name)


def hydra_output_dir_resolver(ckpt_root_dir: str, experiment_name: str) -> str:
    if ckpt_root_dir is None:
        output_dir_path = environment_config.PKG_CHECKPOINTS_DIR + os.path.sep + experiment_name
    else:
        output_dir_path = ckpt_root_dir + os.path.sep + experiment_name
    return output_dir_path


def register_hydra_resolvers():
    """Register all the hydra resolvers required for the super-gradients recipes."""
    OmegaConf.register_new_resolver("hydra_output_dir", hydra_output_dir_resolver, replace=True)
    OmegaConf.register_new_resolver("class", lambda *args: get_cls(*args), replace=True)
    OmegaConf.register_new_resolver("add", lambda *args: sum(args), replace=True)
    OmegaConf.register_new_resolver("cond", lambda boolean, x, y: x if boolean else y, replace=True)
    OmegaConf.register_new_resolver("getitem", lambda container, key: container[key], replace=True)  # get item from a container (list, dict...)
    OmegaConf.register_new_resolver("first", lambda lst: lst[0], replace=True)  # get the first item from a list
    OmegaConf.register_new_resolver("last", lambda lst: lst[-1], replace=True)  # get the last item from a list
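
For illustration, a hypothetical config resolved through these resolvers once they are registered (the keys are made up; assumes omegaconf >= 2.1 so resolver arguments are parsed as typed values):

# Hypothetical example, not part of the PR.
from omegaconf import OmegaConf

from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers

register_hydra_resolvers()

cfg = OmegaConf.create(
    {
        "optimizer_cls": "${class:torch.optim.Adam}",  # resolves to the Adam class itself
        "total_epochs": "${add:100,20}",               # 120
        "first_gpu": "${first:${gpu_ids}}",            # 0
        "gpu_ids": [0, 1, 2, 3],
    }
)

print(cfg.total_epochs)   # 120
print(cfg.first_gpu)      # 0
print(cfg.optimizer_cls)  # <class 'torch.optim.adam.Adam'>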
@@ -12,7 +12,7 @@ from PIL import Image
 from super_gradients.common import ADNNModelRepositoryDataInterfaces
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.auto_logging import AutoLoggerConfig, ConsoleSink
-from super_gradients.common.environment.env_helpers import multi_process_safe
+from super_gradients.common.environment.ddp_utils import multi_process_safe
 from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
 from super_gradients.training.params import TrainingParams
 from super_gradients.training.utils import sg_trainer_utils
@@ -9,7 +9,7 @@ import torch
 from super_gradients.common.abstractions.abstract_logger import get_logger

 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
-from super_gradients.common.environment.env_helpers import multi_process_safe
+from super_gradients.common.environment.ddp_utils import multi_process_safe

 logger = get_logger(__name__)

@@ -2,7 +2,7 @@ import os

 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
-from super_gradients.common.environment.env_helpers import multi_process_safe
+from super_gradients.common.environment.ddp_utils import multi_process_safe

 logger = get_logger(__name__)

@@ -9,7 +9,7 @@ import torch
 from super_gradients.common.abstractions.abstract_logger import get_logger

 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
-from super_gradients.common.environment.env_helpers import multi_process_safe
+from super_gradients.common.environment.ddp_utils import multi_process_safe

 logger = get_logger(__name__)

@@ -20,7 +20,7 @@ Example: python evaluate_checkpoint.py --experiment_name=my_experiment_name --ck

 """
 from super_gradients import Trainer, init_trainer
-from super_gradients.common.environment.env_helpers import pop_arg
+from super_gradients.common.environment.ddp_utils import pop_arg


 def main() -> None:
@@ -4,7 +4,7 @@ Example code for resuming SuperGradient's recipes.
 General use: python resume_experiment.py --experiment_name=<PREVIOUSLY-RUN-EXPERIMENT>
 """
 from super_gradients import Trainer, init_trainer
-from super_gradients.common.environment.env_helpers import pop_arg
+from super_gradients.common.environment.ddp_utils import pop_arg


 def main() -> None:
@@ -7,7 +7,7 @@ from pathlib import Path
 from packaging.version import Version

 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common.environment.env_helpers import is_main_process
+from super_gradients.common.environment.ddp_utils import is_main_process

 LIB_CHECK_IMPOSSIBLE_MSG = 'Library check is not supported when super_gradients installed through "git+https://github.com/..." command'

@@ -24,7 +24,7 @@ from super_gradients.common.factories.callbacks_factory import CallbacksFactory
 from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad, EvaluationType
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.environment import env_helpers
+from super_gradients.common.environment import ddp_utils
 from super_gradients.common.abstractions.abstract_logger import get_logger, mute_current_process
 from super_gradients.common.factories.list_factory import ListFactory
 from super_gradients.common.factories.losses_factory import LossesFactory
@@ -81,7 +81,6 @@ from super_gradients.training.utils.callbacks import (
     ContextSgMethods,
     LRCallbackBase,
 )
-from super_gradients.common.environment import environment_config
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg
 from omegaconf import OmegaConf
@@ -1433,7 +1432,7 @@ class Trainer:
                         logger.warning("\n[WARNING] - Tried running on multiple GPU but only a single GPU is available\n")
                 else:
                     if requested_multi_gpu == MultiGPUMode.AUTO:
-                        if env_helpers.is_distributed():
+                        if ddp_utils.is_distributed():
                             requested_multi_gpu = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
                         else:
                             requested_multi_gpu = MultiGPUMode.DATA_PARALLEL
@@ -1456,7 +1455,7 @@ class Trainer:
         batch you specify times the number of GPUs. In the literature there are several "best practices" to set
         learning rates and schedules for large batch sizes.
         """
-        local_rank = environment_config.DDP_LOCAL_RANK
+        local_rank = ddp_utils.DDP_LOCAL_RANK
         if local_rank > 0:
             mute_current_process()

@@ -11,9 +11,10 @@ from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch

 from super_gradients.common.data_types.enum import MultiGPUMode
-from super_gradients.common.environment.env_helpers import find_free_port, is_distributed
+from super_gradients.common.environment.argparse_utils import EXTRA_ARGS
+from super_gradients.common.environment.ddp_utils import find_free_port, is_distributed
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common.environment import environment_config
+

 logger = get_logger(__name__)

@@ -238,7 +239,7 @@ def restart_script_with_ddp(num_gpus: int = None):
         metrics_cfg={},
     )

-    elastic_launch(config=config, entrypoint=sys.executable)(*sys.argv, *environment_config.EXTRA_ARGS)
+    elastic_launch(config=config, entrypoint=sys.executable)(*sys.argv, *EXTRA_ARGS)

     # The code below should actually never be reached as the process will be in a loop inside elastic_launch until any subprocess crashes.
     sys.exit("Main process finished")