adafactor.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch
import torch.optim

from . import FairseqOptimizer, register_optimizer


@register_optimizer('adafactor')
class FairseqAdafactor(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = Adafactor(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E",
                            help='epsilons for Adafactor optimizer')
        parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C",
                            help='threshold for clipping update root mean square')
        parser.add_argument('--decay-rate', type=float, default=-0.8, metavar="D",
                            help='decay rate of the second moment estimator')
        parser.add_argument('--beta1', type=float, default=None, metavar="B",
                            help='beta for first moment estimator. Optional')
        parser.add_argument('--scale-parameter', action='store_true',
                            help='scale learning rate by root mean square of parameter')
        parser.add_argument('--warmup-init', action='store_true',
                            help='use relative step for warm-up learning rate schedule')
        parser.add_argument('--relative-step', action='store_true',
                            help='set learning rate to inverse square root of timestep. '
                                 'If False, an external learning rate is applied')

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        Note: convergence issues have been observed empirically with fp16;
        finding an appropriate configuration may require some search.
        """
        return {
            'lr': self.args.lr[0],
            'eps': eval(self.args.adafactor_eps),
            'clip_threshold': self.args.clip_threshold,
            'beta1': self.args.beta1,
            'decay_rate': self.args.decay_rate,
            'scale_parameter': self.args.scale_parameter,
            'weight_decay': self.args.weight_decay,
            'relative_step': self.args.relative_step,
            'warmup_init': self.args.warmup_init,
        }
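        # For reference (illustrative comment, not in the original file): with the
        # argparse defaults above, the store_true flags unset, and the shared
        # fairseq options --lr and --weight-decay left at whatever the training
        # config provides, this property evaluates to something like
        #     {'lr': args.lr[0], 'eps': (1e-30, 1e-3), 'clip_threshold': 1.0,
        #      'beta1': None, 'decay_rate': -0.8, 'scale_parameter': False,
        #      'weight_decay': args.weight_decay, 'relative_step': False,
        #      'warmup_init': False}
        # Note that eval() turns the '--adafactor-eps' string default into a
        # tuple of floats.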


class Adafactor(torch.optim.Optimizer):
    """Implements the Adafactor algorithm.

    This implementation is based on:
    `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
    (see https://arxiv.org/abs/1804.04235)

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for the squared gradient
            and the parameter scale, respectively (default: (1e-30, 1e-3))
        clip_threshold (float): threshold on the root mean square of the
            final gradient update (default: 1.0)
        decay_rate (float): coefficient used to compute running averages of the
            squared gradient (default: -0.8)
        beta1 (float): coefficient used for computing running averages of the gradient
            (default: None)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_parameter (bool): if True, the learning rate is scaled by the root mean
            square of the parameter (default: True)
        relative_step (bool): if True, a time-dependent learning rate is computed
            instead of using an external learning rate (default: True)
        warmup_init (bool): the time-dependent learning rate computation depends on
            whether warm-up initialization is being used (default: False)
    """

    def __init__(self, params, lr=None, eps=(1e-30, 1e-3), clip_threshold=1.0,
                 decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True,
                 relative_step=True, warmup_init=False):
        defaults = dict(lr=lr, eps=eps, clip_threshold=clip_threshold, decay_rate=decay_rate,
                        beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
                        relative_step=relative_step, warmup_init=warmup_init)
        super(Adafactor, self).__init__(params, defaults)

    def _get_lr(self, param_group, param_state):
        # Relative-step schedule: lr_t = min(min_step, 1/sqrt(t)), optionally scaled
        # by the parameter's root mean square (with a floor of eps[1]).
        rel_step_sz = param_group['lr']
        if param_group['relative_step']:
            min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state['step']))
        param_scale = 1.0
        if param_group['scale_parameter']:
            param_scale = max(param_group['eps'][1], param_state['RMS'])
        return param_scale * rel_step_sz

    def _get_options(self, param_group, param_shape):
        # Factor the second-moment estimate only for parameters with >= 2 dimensions.
        factored = len(param_shape) >= 2
        use_first_moment = param_group['beta1'] is not None
        return factored, use_first_moment

    def _rms(self, tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, output):
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        torch.mul(r_factor, c_factor, out=output)
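    # How the factored estimate is used (explanatory comment, not in the original
    # file): for a 2-D gradient G, step() keeps only running row means
    # R = EMA(mean(G**2 + eps, dim=-1)) and column means
    # C = EMA(mean(G**2 + eps, dim=-2)). _approx_sq_grad reconstructs
    #     V_hat[i, j] = R[i] * C[j] / mean(R)
    # (equivalently, row_sum * col_sum / total_sum), and the gradient is divided
    # elementwise by sqrt(V_hat), so the second-moment state costs O(rows + cols)
    # memory instead of O(rows * cols).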

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
                        # Row and column statistics of the squared gradient
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).type_as(grad)
                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).type_as(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)
                    state['RMS'] = 0

                state['step'] += 1
                state['RMS'] = self._rms(p.data)
                lr = self._get_lr(group, state)

                # Decay coefficient grows towards 1 with the step count (decay_rate < 0)
                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                update = (grad ** 2) + group['eps'][0]
                if factored:
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
                    exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))

                    # Approximation of exponential moving average of square of gradient
                    self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state['exp_avg_sq']

                    exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
                    torch.rsqrt(exp_avg_sq, out=update).mul_(grad)

                # Clip by the root mean square of the update, then apply the step size
                update.div_(max(1.0, self._rms(update) / group['clip_threshold']))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(1 - group['beta1'], update)
                    update = exp_avg

                if group['weight_decay'] != 0:
                    p.data.add_(-group['weight_decay'] * lr, p.data)

                p.data.add_(-update)

        return loss
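
For context, here is a minimal usage sketch of the Adafactor class defined above, outside of the fairseq wrapper. It assumes the class has been copied into a standalone module (dropping the relative `from . import ...` line and the FairseqAdafactor registration) and a PyTorch version contemporaneous with this code, since step() relies on the older in-place `add_(alpha, tensor)` call signature. The module name, model, data, and training loop below are toy placeholders, not part of the original file.

import torch
import torch.nn as nn

# Hypothetical: assumes the Adafactor class above lives in a standalone module.
# from adafactor_standalone import Adafactor

model = nn.Linear(16, 4)      # 2-D weight, so the factored second-moment path is used
criterion = nn.MSELoss()

# With relative_step=True the step size follows min(1e-2, 1/sqrt(t)),
# so no external learning rate needs to be supplied.
optimizer = Adafactor(
    model.parameters(),
    lr=None,
    scale_parameter=True,
    relative_step=True,
    warmup_init=False,
)

for _ in range(100):
    x = torch.randn(8, 16)    # toy batch
    y = torch.randn(8, 4)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()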