distributed_utils.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import namedtuple
import os
import pickle
import subprocess

import torch
import torch.distributed as dist
from torch import nn

from fairseq import utils


def is_master(args):
    return args.distributed_rank == 0


def infer_init_method(args):
    if args.distributed_init_method is not None:
        return

    # support torch.distributed.launch
    if all(key in os.environ for key in [
        'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK'
    ]):
        args.distributed_init_method = 'tcp://{addr}:{port}'.format(
            addr=os.environ['MASTER_ADDR'],
            port=os.environ['MASTER_PORT'],
        )
        args.distributed_world_size = int(os.environ['WORLD_SIZE'])
        args.distributed_rank = int(os.environ['RANK'])

    # we can determine the init method automatically for Slurm
    elif args.distributed_port > 0:
        node_list = os.environ.get('SLURM_JOB_NODELIST')
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                args.distributed_init_method = 'tcp://{host}:{port}'.format(
                    host=hostnames.split()[0].decode('utf-8'),
                    port=args.distributed_port)
                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
                args.device_id = int(os.environ.get('SLURM_LOCALID'))
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError:  # Slurm is not installed
                pass


def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')

    print('| distributed init (rank {}): {}'.format(
        args.distributed_rank, args.distributed_init_method), flush=True)

    dist.init_process_group(
        backend=args.distributed_backend,
        init_method=args.distributed_init_method,
        world_size=args.distributed_world_size,
        rank=args.distributed_rank,
    )

    suppress_output(is_master(args))

    return args.distributed_rank


def suppress_output(is_master):
    """Suppress printing on the current device. Force printing with `force=True`."""
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def get_rank():
    return dist.get_rank()


def get_world_size():
    return dist.get_world_size()


def get_default_group():
    return dist.group.WORLD


def all_reduce(tensor, group=None):
    if group is None:
        group = get_default_group()
    return dist.all_reduce(tensor, group=group)


def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group (optional): group of the collective
        max_size (int, optional): maximum size of the data to be gathered
            across workers
    """
    rank = get_rank()
    world_size = get_world_size()

    buffer_size = max_size * world_size
    if not hasattr(all_gather_list, '_buffer') or \
            all_gather_list._buffer.numel() < buffer_size:
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
    buffer = all_gather_list._buffer
    buffer.zero_()

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
    assert max_size < 255*256
    buffer_rank = buffer[rank * max_size : (rank + 1) * max_size]
    buffer_rank[0] = enc_size // 255  # this encoding works for max_size < 65k
    buffer_rank[1] = enc_size % 255
    buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc))

    all_reduce(buffer, group=group)

    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size : (i + 1) * max_size]
            size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
            if size > 0:
                result.append(
                    pickle.loads(bytes(out_buffer[2:size+2].tolist()))
                )
        return result
    except pickle.UnpicklingError:
        raise Exception(
            'Unable to unpickle data from other workers. all_gather_list requires all '
            'workers to enter the function together, so this error usually indicates '
            'that the workers have fallen out of sync somehow. Workers can fall out of '
            'sync if one of them runs out of memory, or if there are other conditions '
            'in your training script that can cause one worker to finish an epoch '
            'while other workers are still iterating over their portions of the data.'
        )
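
For reference, here is a minimal sketch of how these helpers are typically wired together from a training script. The driver below is hypothetical: the file name demo.py, the flag set, and the defaults are assumptions modeled on the args fields this module reads, not fairseq's actual entry point; it is only meant to show the call order.

# Hypothetical driver (demo.py, not part of distributed_utils.py). Launched
# with, e.g.:
#
#   python -m torch.distributed.launch --nproc_per_node=2 demo.py
#
# the launcher sets MASTER_ADDR, MASTER_PORT, WORLD_SIZE and RANK in each
# worker's environment, which is exactly what infer_init_method looks for.
import argparse

import torch

from fairseq import distributed_utils


def main():
    parser = argparse.ArgumentParser()
    # Field names match what distributed_utils reads from `args`; the
    # defaults are illustrative assumptions, not fairseq's real defaults.
    parser.add_argument('--distributed-init-method', default=None)
    parser.add_argument('--distributed-backend', default='nccl')
    parser.add_argument('--distributed-world-size', type=int, default=1)
    parser.add_argument('--distributed-rank', type=int, default=0)
    parser.add_argument('--distributed-port', type=int, default=-1)
    parser.add_argument('--device-id', type=int, default=0)
    # torch.distributed.launch injects --local_rank into each worker's argv.
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    # Fill in args.distributed_* from the environment (launcher or Slurm).
    distributed_utils.infer_init_method(args)
    if args.distributed_init_method is None:
        raise RuntimeError('no init method found; run under torch.distributed.launch')

    # NCCL expects one process per GPU, each pinned to its own device.
    torch.cuda.set_device(args.local_rank)
    rank = distributed_utils.distributed_init(args)

    # distributed_init installed suppress_output, so a plain print() is now a
    # no-op on non-master ranks; force=True bypasses the suppression.
    print('hello from rank {}'.format(rank), force=True)


if __name__ == '__main__':
    main()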
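
The body of all_gather_list layers a small framing protocol on top of all_reduce: each rank writes a two-byte length header (decoded as enc_size = 255 * byte0 + byte1, hence the assert max_size < 255*256) followed by its pickled payload into its own max_size slice of a zeroed shared buffer, so one sum-reduction concatenates every rank's contribution without a separate size exchange. A minimal use, written as a hypothetical continuation of main() in the sketch above (process group already initialized):

    # Continues main() above: gather a small picklable dict from every rank.
    stats = {'rank': rank, 'nsentences': 10 * (rank + 1)}

    # All ranks must call this together, and each payload must pickle to at
    # most max_size - 2 bytes (two bytes are reserved for the length header).
    gathered = distributed_utils.all_gather_list(stats, max_size=16384)

    if distributed_utils.is_master(args):
        total = sum(s['nsentences'] for s in gathered)
        print('gathered {} dicts, {} sentences in total'.format(
            len(gathered), total), force=True)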