#970 Update YoloNASQuickstart.md

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bugfix/SG-000_fix_readme_yolonas_snippets

import sys
import os
import itertools
from typing import List, Tuple
from contextlib import contextmanager

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda.amp import autocast
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from super_gradients.common.environment.ddp_utils import init_trainer
from super_gradients.common.data_types.enum import MultiGPUMode
from super_gradients.common.environment.argparse_utils import EXTRA_ARGS
from super_gradients.common.environment.ddp_utils import find_free_port, is_distributed, is_launched_using_sg
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.abstractions.mute_processes import mute_current_process
from super_gradients.common.environment.device_utils import device_config
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.type_factory import TypeFactory

logger = get_logger(__name__)


def distributed_all_reduce_tensor_average(tensor, n):
    """
    This method performs a reduce operation on multiple nodes running distributed training.
    It first sums all of the results and then divides the summation.
    :param tensor: The tensor to perform the reduce operation for
    :param n:      Number of nodes
    :return:       Averaged tensor from all of the nodes
    """
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    rt /= n
    return rt


def reduce_results_tuple_for_ddp(validation_results_tuple, device):
    """Gather all validation tuples from the various devices and average them"""
    validation_results_list = list(validation_results_tuple)
    for i, validation_result in enumerate(validation_results_list):
        if torch.is_tensor(validation_result):
            validation_result = validation_result.clone().detach()
        else:
            validation_result = torch.tensor(validation_result)
        validation_results_list[i] = distributed_all_reduce_tensor_average(tensor=validation_result.to(device), n=torch.distributed.get_world_size())
    validation_results_tuple = tuple(validation_results_list)
    return validation_results_tuple


class MultiGPUModeAutocastWrapper:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        with autocast():
            out = self.func(*args, **kwargs)
        return out


def scaled_all_reduce(tensors: torch.Tensor, num_gpus: int):
    """
    Performs the scaled all_reduce operation on the provided tensors.
    The input tensors are modified in-place.
    Currently supports only the sum reduction operator.
    The reduced values are scaled by the inverse size of the process group (equivalent to num_gpus).
    """
    # There is no need for reduction in the single-proc case
    if num_gpus == 1:
        return tensors

    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)

    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()

    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / num_gpus)
    return tensors
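
# Illustrative note (not from the original module): scaled_all_reduce([t], num_gpus=4)
# sums t across the 4 processes and then multiplies the result by 1/4 in place, so every
# process ends up holding the cross-GPU average of t.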


@torch.no_grad()
def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoader, precise_bn_batch_size: int, num_gpus: int):
    """
    :param model:                 The model being trained (i.e. Trainer.net)
    :param loader:                Training dataloader (i.e. Trainer.train_loader)
    :param precise_bn_batch_size: The effective batch size we want to calculate the batchnorm on. For example, if we are training a model
                                  on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192
                                  (i.e. effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).
                                  If precise_bn_batch_size is not provided in the training_params, the latter heuristic will be taken.
    :param num_gpus:              The number of gpus we are training on
    """
    # Compute the number of minibatches to use
    num_iter = int(precise_bn_batch_size / (loader.batch_size * num_gpus)) if precise_bn_batch_size else num_gpus
    num_iter = min(num_iter, len(loader))

    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]

    # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bns]

    # Remember momentum values
    momentums = [bn.momentum for bn in bns]

    # Set momentum to 1.0 to compute BN stats that only reflect the current batch
    for bn in bns:
        bn.momentum = 1.0

    # Average the BN stats for each BN layer over the batches
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        for i, bn in enumerate(bns):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter

    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, num_gpus=num_gpus)
    running_vars = scaled_all_reduce(running_vars, num_gpus=num_gpus)

    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bns):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]
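
# Illustrative note (not from the original module), using the numbers in the docstring above:
# with 8 GPUs, a per-GPU batch size of 128 and precise_bn_batch_size=8192, each process runs
# num_iter = 8192 / (128 * 8) = 8 minibatches to re-estimate the BatchNorm running statistics.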


def get_local_rank():
    """
    Returns the local rank if running in DDP, and 0 otherwise
    :return: local rank
    """
    return dist.get_rank() if dist.is_initialized() else 0


def require_ddp_setup() -> bool:
    return device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and device_config.assigned_rank != get_local_rank()


def is_ddp_subprocess():
    return torch.distributed.get_rank() > 0 if dist.is_initialized() else False


def get_world_size() -> int:
    """
    Returns the world size if running in DDP, and 1 otherwise
    :return: world size
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_device_ids() -> List[int]:
    return list(range(get_world_size()))


def count_used_devices() -> int:
    return len(get_device_ids())


@contextmanager
def wait_for_the_master(local_rank: int):
    """
    Make all processes wait for the master to finish some task.
    """
    if local_rank > 0:
        dist.barrier()
    yield
    if local_rank == 0:
        if not dist.is_available():
            return
        if not dist.is_initialized():
            return
        else:
            dist.barrier()
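
# Illustrative usage (not from the original module; download_dataset is a hypothetical helper):
# with wait_for_the_master(get_local_rank()):
#     if get_local_rank() == 0:
#         download_dataset()  # only the master downloads; other ranks block on the barrier until it finishes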


def setup_gpu_mode(gpu_mode: MultiGPUMode = MultiGPUMode.OFF, num_gpus: int = None):
    """[DEPRECATED in favor of setup_device] If required, launch ddp subprocesses.

    :param gpu_mode:    DDP, DP, Off or AUTO
    :param num_gpus:    Number of GPUs to use. When None, use all available devices on DDP or only one device on DP/OFF.
    """
    logger.warning("setup_gpu_mode is now deprecated in favor of setup_device")
    setup_device(multi_gpu=gpu_mode, num_gpus=num_gpus)


@resolve_param("multi_gpu", TypeFactory(MultiGPUMode.dict()))
def setup_device(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None, device: str = "cuda"):
    """
    If required, launch ddp subprocesses.

    :param multi_gpu:   DDP, DP, Off or AUTO
    :param num_gpus:    Number of GPUs to use. When None, use all available devices on DDP or only one device on DP/OFF.
    :param device:      The device you want to use ('cpu' or 'cuda')

    If you only set num_gpus, your device will be set up according to the following logic:
        - `setup_device(num_gpus=0)`  => `gpu_mode='OFF'` and `device='cpu'`
        - `setup_device(num_gpus=1)`  => `gpu_mode='OFF'` and `device='gpu'`
        - `setup_device(num_gpus>=2)` => `gpu_mode='DDP'` and `device='gpu'`
        - `setup_device(num_gpus=-1)` => `gpu_mode='DDP'` and `device='gpu'` and `num_gpus=<N-AVAILABLE-GPUs>`
    """
    init_trainer()

    # When launching with torch.distributed.launch or torchrun, multi_gpu might not be set to DDP (since we are not using the recipe params).
    # To avoid any issue we force multi_gpu to be DDP if the current process is a DDP subprocess. We also set num_gpus and device to run smoothly.
    if not is_launched_using_sg() and is_distributed():
        multi_gpu, num_gpus, device = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, None, "cuda"

    if device is None:
        device = "cuda"

    if device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA device is not available on your device... Moving to CPU.")
        multi_gpu, num_gpus, device = MultiGPUMode.OFF, 0, "cpu"

    if device == "cpu":
        setup_cpu(multi_gpu, num_gpus)
    elif device == "cuda":
        setup_gpu(multi_gpu, num_gpus)
    else:
        raise ValueError(f"Only valid values for device are: 'cpu' and 'cuda'. Received: '{device}'")


def setup_cpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
    """
    :param multi_gpu:   DDP, DP, Off or AUTO
    :param num_gpus:    Number of GPUs to use.
    """
    if multi_gpu not in (MultiGPUMode.OFF, MultiGPUMode.AUTO):
        raise ValueError(f"device='cpu' and multi_gpu={multi_gpu} are not compatible together.")

    if num_gpus not in (0, None):
        raise ValueError(f"device='cpu' and num_gpus={num_gpus} are not compatible together.")

    device_config.device = "cpu"
    device_config.multi_gpu = MultiGPUMode.OFF


def setup_gpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
    """
    If required, launch ddp subprocesses.

    :param multi_gpu:   DDP, DP, Off or AUTO
    :param num_gpus:    Number of GPUs to use. When None, use all available devices on DDP or only one device on DP/OFF.
    """
    if num_gpus == 0:
        raise ValueError("device='cuda' and num_gpus=0 are not compatible together.")

    multi_gpu, num_gpus = _resolve_gpu_params(multi_gpu=multi_gpu, num_gpus=num_gpus)

    device_config.device = "cuda"
    device_config.multi_gpu = multi_gpu

    if is_distributed():
        initialize_ddp()
    elif multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
        restart_script_with_ddp(num_gpus=num_gpus)


def _resolve_gpu_params(multi_gpu: MultiGPUMode, num_gpus: int) -> Tuple[MultiGPUMode, int]:
    """
    Resolve the values multi_gpu in (None, MultiGPUMode.AUTO) and num_gpus in (None, -1), and check compatibility between both parameters.

    :param multi_gpu:   DDP, DP, Off or AUTO
    :param num_gpus:    Number of GPUs to use. When None, use all available devices on DDP or only one device on DP/OFF.
    """
    # Resolve None
    if multi_gpu is None:
        if num_gpus is None:  # When nothing is specified, just run on a single GPU
            multi_gpu = MultiGPUMode.OFF
            num_gpus = 1
        else:
            multi_gpu = MultiGPUMode.AUTO

    if num_gpus is None:
        num_gpus = -1

    # Resolve num_gpus
    if num_gpus == -1:
        if multi_gpu in (MultiGPUMode.OFF, MultiGPUMode.DATA_PARALLEL):
            num_gpus = 1
        elif multi_gpu in (MultiGPUMode.AUTO, MultiGPUMode.DISTRIBUTED_DATA_PARALLEL):
            num_gpus = torch.cuda.device_count()

    # Resolve multi_gpu
    if multi_gpu == MultiGPUMode.AUTO:
        if num_gpus > 1:
            multi_gpu = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL
        else:
            multi_gpu = MultiGPUMode.OFF

    # Check compatibility between num_gpus and multi_gpu
    if multi_gpu in (MultiGPUMode.OFF, MultiGPUMode.DATA_PARALLEL):
        if num_gpus != 1:
            raise ValueError(f"You specified num_gpus={num_gpus} but it has to be 1 when working with multi_gpu={multi_gpu}")
    else:
        if num_gpus > torch.cuda.device_count():
            raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPUs are available")

    return multi_gpu, num_gpus


def initialize_ddp():
    """
    Initialize Distributed Data Parallel.

    Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU.
    Whatever learning rate and schedule you specify will be applied to each GPU individually.
    Since gradients are passed and summed (reduced) from all to all GPUs, the effective batch size is the
    batch you specify times the number of GPUs. In the literature there are several "best practices" to set
    learning rates and schedules for large batch sizes.
    """
    if device_config.assigned_rank > 0:
        mute_current_process()

    logger.info("Distributed training starting...")
    if not torch.distributed.is_initialized():
        backend = "gloo" if os.name == "nt" else "nccl"
        torch.distributed.init_process_group(backend=backend, init_method="env://")
    torch.cuda.set_device(device_config.assigned_rank)

    if torch.distributed.get_rank() == 0:
        logger.info(f"Training in distributed mode... with {str(torch.distributed.get_world_size())} GPUs")
    device_config.device = "cuda:%d" % device_config.assigned_rank
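
# Illustrative note (not from the original module): with a per-GPU batch size of 32 on 4 GPUs,
# the effective batch size is 32 * 4 = 128, so learning-rate scaling heuristics for batch
# size 128 apply, as described in the docstring of initialize_ddp above.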


@record
def restart_script_with_ddp(num_gpus: int = None):
    """Launch the same script as the one that was launched (i.e. the command used to start the current process is re-used) but on subprocesses (i.e. with DDP).

    :param num_gpus: How many GPUs you want to run the script on. If not specified, every available device will be used.
    """
    ddp_port = find_free_port()

    # Get the value from the recipe if specified, otherwise take all available devices.
    num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
    if num_gpus > torch.cuda.device_count():
        raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPUs are available")

    logger.info(
        "Launching DDP with:\n"
        f"   - ddp_port = {ddp_port}\n"
        f"   - num_gpus = {num_gpus}/{torch.cuda.device_count()} available\n"
        "-------------------------------------\n"
    )

    config = LaunchConfig(
        nproc_per_node=num_gpus,
        min_nodes=1,
        max_nodes=1,
        run_id="sg_initiated",
        role="default",
        rdzv_endpoint=f"127.0.0.1:{ddp_port}",
        rdzv_backend="static",
        rdzv_configs={"rank": 0, "timeout": 900},
        rdzv_timeout=-1,
        max_restarts=0,
        monitor_interval=5,
        start_method="spawn",
        log_dir=None,
        redirects=Std.NONE,
        tee=Std.NONE,
        metrics_cfg={},
    )

    elastic_launch(config=config, entrypoint=sys.executable)(*sys.argv, *EXTRA_ARGS)

    # The code below should actually never be reached as the process will be in a loop inside elastic_launch until any subprocess crashes.
    sys.exit(0)


def get_gpu_mem_utilization():
    """GPU memory managed by the caching allocator in bytes for a given device."""
    # Workaround to work on any torch version
    if hasattr(torch.cuda, "memory_reserved"):
        return torch.cuda.memory_reserved()
    else:
        return torch.cuda.memory_cached()


class DDPNotSetupException(Exception):
    """Exception raised when DDP setup is required but was not done"""

    def __init__(self):
        self.message = (
            "Your environment was not set up correctly for DDP.\n"
            "Please run at the beginning of your script:\n"
            ">>> from super_gradients.training.utils.distributed_training_utils import setup_device\n"
            ">>> from super_gradients.common.data_types.enum import MultiGPUMode\n"
            ">>> setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=...)"
        )
        super().__init__(self.message)
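
For context, a minimal usage sketch (a hypothetical example, not part of this file or of the diff) showing how the helpers above are typically invoked from a training script; the import path is the one quoted in the DDPNotSetupException message:

from super_gradients.common.data_types.enum import MultiGPUMode
from super_gradients.training.utils.distributed_training_utils import setup_device

# Request DDP on 4 GPUs. If the current process is not already a DDP subprocess,
# setup_gpu() falls through to restart_script_with_ddp(), which re-launches this very
# script with one elastic subprocess per GPU; each subprocess then calls initialize_ddp().
setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=4)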