distributed_fairseq_model.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import inspect

from torch.nn import parallel

from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel

from . import BaseFairseqModel


def DistributedFairseqModel(args, model):
    """
    Wrap a *model* to support distributed data parallel training.

    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.

    Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap
    """
    # determine which DDP class to extend
    assert isinstance(model, BaseFairseqModel)
    if args.ddp_backend == 'c10d':
        ddp_class = parallel.DistributedDataParallel
        init_kwargs = dict(
            module=model,
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=False,
            bucket_cap_mb=args.bucket_cap_mb,
        )
        # Maintain backward compatibility for 0.4 or earlier
        if 'check_reduction' in inspect.getargspec(ddp_class)[0]:
            init_kwargs['check_reduction'] = True
    elif args.ddp_backend == 'no_c10d':
        ddp_class = LegacyDistributedDataParallel
        init_kwargs = dict(
            module=model,
            world_size=args.distributed_world_size,
            buffer_size=2**28,
        )
    else:
        raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)

    class _DistributedFairseqModel(ddp_class):
        """Extend DistributedDataParallel to check for missing
        attributes in the wrapped module."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def __getattr__(self, name):
            wrapped_module = super().__getattr__('module')
            if hasattr(wrapped_module, name):
                return getattr(wrapped_module, name)
            return super().__getattr__(name)

    return _DistributedFairseqModel(**init_kwargs)
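
For context, below is a minimal, self-contained sketch of the attribute-forwarding idea that _DistributedFairseqModel relies on. The Wrapper and Inner classes are illustrative only and are not part of fairseq; the snippet assumes only that PyTorch is installed and needs no distributed setup. In actual fairseq training, DistributedFairseqModel(args, model) is instead called with an args namespace carrying the fields read above (ddp_backend, device_id, bucket_cap_mb, distributed_world_size).

# Illustrative sketch only (these classes are not part of fairseq): the same
# __getattr__ fallback used by _DistributedFairseqModel above, shown on a
# plain nn.Module wrapper so it can run without any distributed setup.
import torch.nn as nn


class Inner(nn.Module):
    """Stand-in for a BaseFairseqModel with a model-specific helper."""

    def max_positions(self):
        return 1024


class Wrapper(nn.Module):
    """Forward unknown attribute lookups to the wrapped module."""

    def __init__(self, module):
        super().__init__()
        self.module = module  # registered as a submodule by nn.Module

    def __getattr__(self, name):
        # Only reached when normal lookup fails; try the wrapped module first.
        wrapped_module = super().__getattr__('module')
        if hasattr(wrapped_module, name):
            return getattr(wrapped_module, name)
        return super().__getattr__(name)


wrapped = Wrapper(Inner())
print(wrapped.max_positions())  # 1024 -- forwarded to the wrapped Inner module

Because __getattr__ is only invoked when normal attribute lookup fails, attributes defined on the wrapper itself still take precedence; anything else, such as a model's max_positions() or its encoder submodule, is looked up on the wrapped module first, which is what lets training code treat the DDP-wrapped model like the underlying fairseq model.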