|
@@ -11,9 +11,10 @@ from torch.distributed.elastic.multiprocessing.errors import record
|
|
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
|
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
|
|
|
|
|
from super_gradients.common.data_types.enum import MultiGPUMode
|
|
from super_gradients.common.data_types.enum import MultiGPUMode
|
|
-from super_gradients.common.environment.env_helpers import find_free_port, is_distributed
|
|
|
|
|
|
+from super_gradients.common.environment.argparse_utils import EXTRA_ARGS
|
|
|
|
+from super_gradients.common.environment.ddp_utils import find_free_port, is_distributed
|
|
from super_gradients.common.abstractions.abstract_logger import get_logger
|
|
from super_gradients.common.abstractions.abstract_logger import get_logger
|
|
-from super_gradients.common.environment import environment_config
|
|
|
|
|
|
+
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
@@ -238,7 +239,7 @@ def restart_script_with_ddp(num_gpus: int = None):
|
|
metrics_cfg={},
|
|
metrics_cfg={},
|
|
)
|
|
)
|
|
|
|
|
|
- elastic_launch(config=config, entrypoint=sys.executable)(*sys.argv, *environment_config.EXTRA_ARGS)
|
|
|
|
|
|
+ elastic_launch(config=config, entrypoint=sys.executable)(*sys.argv, *EXTRA_ARGS)
|
|
|
|
|
|
# The code below should actually never be reached as the process will be in a loop inside elastic_launch until any subprocess crashes.
|
|
# The code below should actually never be reached as the process will be in a loop inside elastic_launch until any subprocess crashes.
|
|
sys.exit("Main process finished")
|
|
sys.exit("Main process finished")
|