distributed_training_utils.py
import torch
from torch import distributed as dist
from torch.cuda.amp import autocast
import torch.nn as nn
import itertools
from contextlib import contextmanager
from typing import List

def distributed_all_reduce_tensor_average(tensor, n):
    """
    Performs an all-reduce across the nodes running distributed training:
    the tensor is summed over all nodes and the sum is then divided by n.
    :param tensor: the tensor to perform the reduce operation on
    :param n: number of nodes
    :return: the averaged tensor from all of the nodes
    """
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    rt /= n
    return rt

def reduce_results_tuple_for_ddp(validation_results_tuple, device):
    """Gather all validation tuples from the various devices and average them"""
    validation_results_list = list(validation_results_tuple)
    for i, validation_result in enumerate(validation_results_list):
        validation_results_list[i] = distributed_all_reduce_tensor_average(
            torch.tensor(validation_result).to(device), torch.distributed.get_world_size()
        )
    validation_results_tuple = tuple(validation_results_list)
    return validation_results_tuple

class MultiGPUModeAutocastWrapper:
    """Wraps a callable so that it always runs under torch.cuda.amp autocast (mixed precision)."""

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        with autocast():
            out = self.func(*args, **kwargs)
        return out

def scaled_all_reduce(tensors: List[torch.Tensor], num_gpus: int):
    """
    Performs the scaled all_reduce operation on the provided tensors.
    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of the
    process group (equivalent to num_gpus).
    """
    # There is no need for reduction in the single-process case
    if num_gpus == 1:
        return tensors
    # Queue the (asynchronous) reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / num_gpus)
    return tensors

@torch.no_grad()
def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoader,
                             precise_bn_batch_size: int, num_gpus: int):
    """
    Recomputes the running statistics of every BatchNorm2d layer over a number of training batches,
    giving more precise statistics than the exponential moving average accumulated during training.
    :param model: the model being trained (i.e. SgModel.net)
    :param loader: training dataloader (i.e. SgModel.train_loader)
    :param precise_bn_batch_size: the effective batch size to compute the batchnorm stats on. For example, when
                                  training on 8 GPUs with a batch of 128 per GPU, a good rule of thumb is 8192
                                  (i.e. effective_batch_size = batch_per_gpu * num_gpus * num_gpus).
                                  If precise_bn_batch_size is not provided in the training_params,
                                  this heuristic is used.
    :param num_gpus: the number of GPUs we are training on
    """
    # Compute the number of minibatches to use
    num_iter = int(precise_bn_batch_size / (loader.batch_size * num_gpus)) if precise_bn_batch_size else num_gpus
    num_iter = min(num_iter, len(loader))
    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values
    momentums = [bn.momentum for bn in bns]
    # Set momentum to 1.0 to compute BN stats that only reflect the current batch
    for bn in bns:
        bn.momentum = 1.0
    # Average the BN stats for each BN layer over the batches
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        for i, bn in enumerate(bns):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter
    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, num_gpus=num_gpus)
    running_vars = scaled_all_reduce(running_vars, num_gpus=num_gpus)
    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bns):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]

def get_local_rank():
    """
    Returns the local rank if running in DDP, and 0 otherwise
    :return: local rank
    """
    return dist.get_rank() if dist.is_initialized() else 0


def get_world_size() -> int:
    """
    Returns the world size if running in DDP, and 1 otherwise
    :return: world size
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()

@contextmanager
def wait_for_the_master(local_rank: int):
    """
    Makes all non-master processes wait for the master (local_rank 0) to finish the task
    performed inside the context: ranks > 0 block on a barrier before running the body,
    and rank 0 releases them by reaching the matching barrier when it exits the context.
    """
    if local_rank > 0:
        dist.barrier()
    yield
    if local_rank == 0:
        if not dist.is_available():
            return
        if not dist.is_initialized():
            return
        else:
            dist.barrier()
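A brief usage sketch of the DDP helpers above. This is only an illustrative sketch, not part of the file: it assumes the process group is initialized via torchrun, that this module is importable as distributed_training_utils (the actual package path may differ), and build_dataset is a hypothetical placeholder for your own data-preparation code.

import torch
from torch import distributed as dist
from distributed_training_utils import (
    get_local_rank,
    wait_for_the_master,
    reduce_results_tuple_for_ddp,
)

# Assumes a torchrun launch, which sets the env vars init_process_group needs.
dist.init_process_group(backend="nccl")

# Note: this module's get_local_rank returns the global rank, which matches
# the local rank when training on a single node.
local_rank = get_local_rank()
device = torch.device("cuda", local_rank)

# All ranks run the body, but rank 0 enters it first (e.g. to download or cache
# files); the other ranks start only after rank 0 leaves the context.
with wait_for_the_master(local_rank):
    dataset = build_dataset()  # hypothetical helper, not part of this module

# After a validation pass, average the per-rank metrics across all processes.
# The two scalars below are dummy values standing in for each rank's own results.
val_loss, val_acc = 0.42, 0.87
avg_loss, avg_acc = reduce_results_tuple_for_ddp((val_loss, val_acc), device)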
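Similarly, a sketch of how compute_precise_bn_stats might be called at the end of training. Here model and train_loader are placeholders for your own network and dataloader (e.g. SgModel.net and SgModel.train_loader), and the effective batch size follows the rule of thumb from the docstring.

from distributed_training_utils import compute_precise_bn_stats, get_world_size

num_gpus = get_world_size()
batch_per_gpu = train_loader.batch_size  # train_loader is a placeholder DataLoader

# Rule of thumb from the docstring: batch_per_gpu * num_gpus * num_gpus
precise_bn_batch_size = batch_per_gpu * num_gpus * num_gpus

model.train()  # BatchNorm layers only update running stats in training mode
compute_precise_bn_stats(model=model,
                         loader=train_loader,
                         precise_bn_batch_size=precise_bn_batch_size,
                         num_gpus=num_gpus)
model.eval()  # switch back before evaluation or export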
