
#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commit into Ultralytics:main from ultralytics:yoloe-vp-fix
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import os
import shutil
import socket
import sys
import tempfile

from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9

def find_free_network_port() -> int:
    """
    Find a free port on localhost.

    It is useful in single-node training when we don't want to connect to a real main node but have to set the
    `MASTER_PORT` environment variable.

    Returns:
        (int): The available network port number.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]  # port
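For context (not part of this diff), the port returned here is typically exported as `MASTER_PORT` so that single-node DDP workers can rendezvous locally. A minimal usage sketch, assuming this module is importable as `ultralytics.utils.dist`:

```python
import os

from ultralytics.utils.dist import find_free_network_port

# Sketch: reserve a free local port and expose it to torch.distributed before
# spawning single-node DDP workers (MASTER_ADDR/MASTER_PORT drive the rendezvous).
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(find_free_network_port())
```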
def generate_ddp_file(trainer):
    """
    Generate a DDP (Distributed Data Parallel) file for multi-GPU training.

    This function creates a temporary Python file that enables distributed training across multiple GPUs.
    The file contains the necessary configuration to initialize the trainer in a distributed environment.

    Args:
        trainer (object): The trainer object containing training configuration and arguments.
            Must have args attribute and be a class instance.

    Returns:
        (str): Path to the generated temporary DDP file.

    Notes:
        The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes:
        - Trainer class import
        - Configuration overrides from the trainer arguments
        - Model path configuration
        - Training initialization code
    """
    module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)

    content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {vars(trainer.args)}

if __name__ == "__main__":
    from {module} import {name}
    from ultralytics.utils import DEFAULT_CFG_DICT

    cfg = DEFAULT_CFG_DICT.copy()
    cfg.update(save_dir='')  # handle the extra key 'save_dir'
    trainer = {name}(cfg=cfg, overrides=overrides)
    trainer.args.model = "{getattr(trainer.hub_session, 'model_url', trainer.args.model)}"
    results = trainer.train()
"""
    (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
    with tempfile.NamedTemporaryFile(
        prefix="_temp_",
        suffix=f"{id(trainer)}.py",
        mode="w+",
        encoding="utf-8",
        dir=USER_CONFIG_DIR / "DDP",
        delete=False,
    ) as file:
        file.write(content)
    return file.name
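To make the template above concrete, a generated temp file for a hypothetical `DetectionTrainer` run would look roughly like the sketch below; the trainer class, `overrides` values, and model path are illustrative only and come from the trainer's actual arguments at runtime:

```python
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {"model": "yolo11n.pt", "data": "coco8.yaml", "epochs": 3}  # example values only

if __name__ == "__main__":
    from ultralytics.models.yolo.detect.train import DetectionTrainer
    from ultralytics.utils import DEFAULT_CFG_DICT

    cfg = DEFAULT_CFG_DICT.copy()
    cfg.update(save_dir='')  # handle the extra key 'save_dir'
    trainer = DetectionTrainer(cfg=cfg, overrides=overrides)
    trainer.args.model = "yolo11n.pt"
    results = trainer.train()
```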
def generate_ddp_command(world_size, trainer):
    """
    Generate command for distributed training.

    Args:
        world_size (int): Number of processes to spawn for distributed training.
        trainer (object): The trainer object containing configuration for distributed training.

    Returns:
        cmd (List[str]): The command to execute for distributed training.
        file (str): Path to the temporary file created for DDP training.
    """
    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218

    if not trainer.resume:
        shutil.rmtree(trainer.save_dir)  # remove the save_dir
    file = generate_ddp_file(trainer)
    dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
    port = find_free_network_port()
    cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
    return cmd, file
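For illustration only (not part of the diff), with `world_size=2` on PyTorch 1.9+ the returned command is a `torch.distributed.run`-style invocation; the interpreter path, port, and temp-file name below are made-up placeholders:

```python
# Hypothetical shape of the cmd list returned by generate_ddp_command(2, trainer)
cmd = [
    "/usr/bin/python3",             # sys.executable
    "-m", "torch.distributed.run",  # torchrun entry point when TORCH_1_9 is True
    "--nproc_per_node", "2",        # world_size
    "--master_port", "51234",       # value from find_free_network_port()
    "/home/user/.config/Ultralytics/DDP/_temp_abc123456789.py",  # generated DDP file (path illustrative)
]
```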
def ddp_cleanup(trainer, file):
    """
    Delete temporary file if created during distributed data parallel (DDP) training.

    This function checks if the provided file contains the trainer's ID in its name, indicating it was created
    as a temporary file for DDP training, and deletes it if so.

    Args:
        trainer (object): The trainer object used for distributed training.
        file (str): Path to the file that might need to be deleted.

    Examples:
        >>> trainer = YOLOTrainer()
        >>> file = "/tmp/ddp_temp_123456789.py"
        >>> ddp_cleanup(trainer, file)
    """
    if f"{id(trainer)}.py" in file:  # if temp_file suffix in file
        os.remove(file)
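Taken together, a caller would typically drive these helpers in sequence: generate the command, launch it as a subprocess, then clean up the temp file. A hedged sketch of that calling pattern, assuming an existing `trainer` instance and the module path `ultralytics.utils.dist`:

```python
import subprocess

from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command

# Sketch: spawn the multi-GPU run via the generated command, then remove the temp DDP file.
world_size = 2  # example GPU count; 'trainer' is assumed to be an already-constructed trainer
cmd, file = generate_ddp_command(world_size, trainer)
try:
    subprocess.run(cmd, check=True)
finally:
    ddp_cleanup(trainer, str(file))
```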