# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch
import torch.optim

from . import FairseqOptimizer, register_optimizer


@register_optimizer('adafactor')
class FairseqAdafactor(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = Adafactor(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E",
                            help='epsilons for Adafactor optimizer')
        parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C",
                            help='threshold for clipping update root mean square')
        parser.add_argument('--decay-rate', type=float, default=-0.8, metavar="D",
                            help='decay rate of the second moment estimator')
        parser.add_argument('--beta1', type=float, default=None, metavar="B",
                            help='beta for first moment estimator (optional)')
        parser.add_argument('--scale-parameter', action='store_true',
                            help='scale learning rate by root mean square of parameter')
        parser.add_argument('--warmup-init', action='store_true',
                            help='use relative step for warm-up learning rate schedule')
        parser.add_argument('--relative-step', action='store_true',
                            help='set learning rate to inverse square root of timestep. '
                                 'If false, the external learning rate is applied')
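
    # These flags are read from the fairseq command line: passing
    # `--optimizer adafactor` selects this class via the @register_optimizer
    # decorator above, and the parsed values are forwarded to the Adafactor
    # constructor through optimizer_config below.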

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.

        Note: convergence issues have been observed empirically when training
        with fp16; finding a suitable configuration may require some search.
        """
        return {
            'lr': self.args.lr[0],
            'eps': eval(self.args.adafactor_eps),
            'clip_threshold': self.args.clip_threshold,
            'beta1': self.args.beta1,
            'decay_rate': self.args.decay_rate,
            'scale_parameter': self.args.scale_parameter,
            'weight_decay': self.args.weight_decay,
            'relative_step': self.args.relative_step,
            'warmup_init': self.args.warmup_init,
        }


class Adafactor(torch.optim.Optimizer):
    """Implements the Adafactor algorithm.

    This implementation is based on:
    `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
    (see https://arxiv.org/abs/1804.04235)

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for the squared gradient
            and the parameter scale, respectively (default: (1e-30, 1e-3))
        clip_threshold (float): threshold on the root mean square of the
            final gradient update (default: 1.0)
        decay_rate (float): coefficient used to compute the running average of the
            squared gradient (default: -0.8)
        beta1 (float): coefficient used to compute the running average of the gradient
            (default: None)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_parameter (bool): if True, the learning rate is scaled by the root mean
            square of the parameter (default: True)
        relative_step (bool): if True, a time-dependent learning rate is computed
            instead of using the external learning rate (default: True)
        warmup_init (bool): the time-dependent learning rate computation depends on
            whether warm-up initialization is being used (default: False)
    """

    def __init__(self, params, lr=None, eps=(1e-30, 1e-3), clip_threshold=1.0,
                 decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True,
                 relative_step=True, warmup_init=False):
        defaults = dict(lr=lr, eps=eps, clip_threshold=clip_threshold, decay_rate=decay_rate,
                        beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
                        relative_step=relative_step, warmup_init=warmup_init)
        super(Adafactor, self).__init__(params, defaults)
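
    # Learning rate schedule used by _get_lr below: with relative_step=True the
    # external lr is replaced by min(max_step, 1/sqrt(t)) at update t, where
    # max_step is 1e-2 by default or 1e-6 * t when warmup_init=True; with
    # scale_parameter=True this is further multiplied by max(eps[1], RMS(param)).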

    def _get_lr(self, param_group, param_state):
        rel_step_sz = param_group['lr']
        if param_group['relative_step']:
            min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state['step']))
        param_scale = 1.0
        if param_group['scale_parameter']:
            param_scale = max(param_group['eps'][1], param_state['RMS'])
        return param_scale * rel_step_sz

    def _get_options(self, param_group, param_shape):
        factored = len(param_shape) >= 2
        use_first_moment = param_group['beta1'] is not None
        return factored, use_first_moment

    def _rms(self, tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)
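
    # Factored second-moment estimation: for parameters with two or more dimensions,
    # only per-row and per-column running averages of the squared gradient are kept,
    # and _approx_sq_grad writes the inverse square root of their normalized outer
    # product into `output`, so optimizer memory grows roughly as O(rows + cols) per
    # matrix rather than O(rows * cols).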

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, output):
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        torch.mul(r_factor, c_factor, out=output)
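
    # In step() the second-moment decay follows the schedule beta2_t = 1 - t**decay_rate
    # (approaching 1 over time with the default decay_rate of -0.8), and each update is
    # rescaled so that its root mean square does not exceed clip_threshold before being
    # multiplied by the learning rate.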

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).type_as(grad)
                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).type_as(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)
                    state['RMS'] = 0

                state['step'] += 1
                state['RMS'] = self._rms(p.data)
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                update = (grad ** 2) + group['eps'][0]
                if factored:
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']
                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)

                    # Approximation of the exponential moving average of the squared gradient
                    self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    torch.rsqrt(exp_avg_sq, out=update).mul_(grad)

                update.div_(max(1.0, self._rms(update) / group['clip_threshold']))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
                    update = exp_avg

                if group['weight_decay'] != 0:
                    p.data.add_(p.data, alpha=-group['weight_decay'] * lr)

                p.data.add_(-update)

        return loss
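

# A minimal standalone sketch (not part of the original module) showing how the
# Adafactor class above can be driven directly, outside of the fairseq wrapper;
# the toy model and data are made up purely for illustration. With the default
# relative_step=True, no external learning rate needs to be supplied.
if __name__ == '__main__':
    import torch.nn as nn
    import torch.nn.functional as F

    model = nn.Linear(16, 4)
    optimizer = Adafactor(model.parameters(), scale_parameter=True,
                          relative_step=True, warmup_init=False)

    x = torch.randn(8, 16)
    target = torch.randn(8, 4)
    for _ in range(5):
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), target)
        loss.backward()
        optimizer.step()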