Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

trainer.py 16 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. """
  8. Train a network across multiple GPUs.
  9. """
  10. from collections import OrderedDict
  11. from itertools import chain
  12. import torch
  13. from fairseq import distributed_utils, models, optim, utils
  14. from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter, ListMeter
  15. from fairseq.optim import lr_scheduler
  16. from fairseq.utils import save_json_metric
  17. class Trainer(object):
  18. """Main class for data parallel training.
  19. This class supports synchronous distributed data parallel training,
  20. where multiple workers each have a full model replica and gradients
  21. are accumulated across workers before each update. We use
  22. :class:`~torch.nn.parallel.DistributedDataParallel` to handle
  23. communication of the gradients across workers.
  24. """
  25. def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
  26. self.args = args
  27. self.task = task
  28. # copy model and criterion to current device
  29. self.criterion = criterion
  30. self._model = model
  31. self.cuda = torch.cuda.is_available() and not args.cpu
  32. if args.fp16:
  33. self._model = self._model.half()
  34. if self.cuda:
  35. self.criterion = self.criterion.cuda()
  36. self._model = self._model.cuda()
  37. self._dummy_batch = dummy_batch
  38. self._oom_batch = oom_batch
  39. self._lr_scheduler = None
  40. self._num_updates = 0
  41. self._optim_history = None
  42. self._optimizer = None
  43. self._wrapped_model = None
  44. self.init_meters(args)
  45. def init_meters(self, args):
  46. self.meters = OrderedDict()
  47. self.meters['train_loss'] = AverageMeter()
  48. self.meters['train_losses'] = ListMeter()
  49. self.meters['train_nll_loss'] = AverageMeter()
  50. self.meters['valid_loss'] = AverageMeter()
  51. self.meters['valid_losses'] = ListMeter()
  52. self.meters['valid_nll_loss'] = AverageMeter()
  53. self.meters['wps'] = TimeMeter() # words per second
  54. self.meters['ups'] = TimeMeter() # updates per second
  55. self.meters['wpb'] = AverageMeter() # words per batch
  56. self.meters['bsz'] = AverageMeter() # sentences per batch
  57. self.meters['gnorm'] = AverageMeter() # gradient norm
  58. self.meters['clip'] = AverageMeter() # % of updates clipped
  59. self.meters['oom'] = AverageMeter() # out of memory
  60. if args.fp16:
  61. self.meters['loss_scale'] = AverageMeter() # dynamic loss scale
  62. self.meters['wall'] = TimeMeter() # wall time in seconds
  63. self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
  64. @property
  65. def model(self):
  66. if self._wrapped_model is None:
  67. if self.args.distributed_world_size > 1:
  68. self._wrapped_model = models.DistributedFairseqModel(
  69. self.args, self._model,
  70. )
  71. else:
  72. self._wrapped_model = self._model
  73. return self._wrapped_model
  74. @property
  75. def optimizer(self):
  76. if self._optimizer is None:
  77. self._build_optimizer()
  78. return self._optimizer
  79. @property
  80. def lr_scheduler(self):
  81. if self._lr_scheduler is None:
  82. self._build_optimizer() # this will initialize self._lr_scheduler
  83. return self._lr_scheduler
  84. def _build_optimizer(self):
  85. params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
  86. if self.args.fp16:
  87. if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
  88. print('| WARNING: your device does NOT support faster training with --fp16, '
  89. 'please switch to FP32 which is likely to be faster')
  90. if self.args.memory_efficient_fp16:
  91. self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
  92. else:
  93. self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
  94. else:
  95. if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
  96. print('| NOTICE: your device may support faster training with --fp16')
  97. self._optimizer = optim.build_optimizer(self.args, params)
  98. # We should initialize the learning rate scheduler immediately after
  99. # building the optimizer, so that the initial learning rate is set.
  100. self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
  101. def save_checkpoint(self, filename, extra_state):
  102. """Save all training state in a checkpoint file."""
  103. if distributed_utils.is_master(self.args): # only save one checkpoint
  104. extra_state['train_meters'] = self.meters
  105. # TODO: Set directory?
  106. if self.meters['train_loss'].count > 10:
  107. save_json_metric('./metrics', self.meters['train_loss'].avg, 'train_loss')
  108. save_json_metric('./metrics', self.meters['wps'].avg, 'wps')
  109. utils.save_state(
  110. filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
  111. self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
  112. )
  113. def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
  114. """Load all training state from a checkpoint file."""
  115. extra_state, self._optim_history, last_optim_state = utils.load_model_state(
  116. filename, self.get_model(),
  117. )
  118. if last_optim_state is not None and not reset_optimizer:
  119. # rebuild optimizer after loading model, since params may have changed
  120. self._build_optimizer()
  121. # only reload optimizer and lr_scheduler if they match
  122. last_optim = self._optim_history[-1]
  123. assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
  124. 'criterion does not match; please reset the optimizer (--reset-optimizer)'
  125. assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
  126. 'optimizer does not match; please reset the optimizer (--reset-optimizer)'
  127. if not reset_lr_scheduler:
  128. self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
  129. self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
  130. self._num_updates = last_optim['num_updates']
  131. if extra_state is not None and 'train_meters' in extra_state:
  132. self.meters.update(extra_state['train_meters'])
  133. del extra_state['train_meters']
  134. # reset TimeMeters, since their start times don't make sense anymore
  135. for meter in self.meters.values():
  136. if isinstance(meter, TimeMeter):
  137. meter.reset()
  138. return extra_state
  139. def train_step(self, samples, dummy_batch=False):
  140. """Do forward, backward and parameter update."""
  141. self._set_seed()
  142. self.model.train()
  143. self.criterion.train()
  144. self.zero_grad()
  145. if not dummy_batch:
  146. self.meters['train_wall'].start()
  147. # forward and backward pass
  148. logging_outputs, sample_sizes, ooms = [], [], 0
  149. for i, sample in enumerate(samples):
  150. sample = self._prepare_sample(sample)
  151. if sample is None:
  152. # when sample is None, run forward/backward on a dummy batch
  153. # and ignore the resulting gradients
  154. sample = self._prepare_sample(self._dummy_batch)
  155. ignore_grad = True
  156. else:
  157. ignore_grad = False
  158. try:
  159. if self.args.distributed_world_size > 1:
  160. # Whenever *samples* contains more than one mini-batch, we
  161. # want to accumulate gradients locally and only call
  162. # all-reduce in the last backwards pass. Currently the
  163. # *need_reduction* flag is only supported by
  164. # LegacyDistributedDataParallel.
  165. if i < len(samples) - 1:
  166. self.model.accumulate_grads = True
  167. else:
  168. self.model.accumulate_grads = False
  169. # forward and backward
  170. loss, sample_size, logging_output = self.task.train_step(
  171. sample, self.model, self.criterion, self.optimizer,
  172. ignore_grad
  173. )
  174. if not ignore_grad:
  175. logging_outputs.append(logging_output)
  176. sample_sizes.append(sample_size)
  177. except RuntimeError as e:
  178. if 'out of memory' in str(e):
  179. print('| WARNING: ran out of memory, skipping batch')
  180. ooms += 1
  181. self.zero_grad()
  182. else:
  183. raise e
  184. if ooms > 0 and self._oom_batch is not None:
  185. self.handle_ooms(ooms)
  186. if dummy_batch:
  187. return None
  188. # gather logging outputs from all replicas
  189. if self.args.distributed_world_size > 1:
  190. logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
  191. [logging_outputs, sample_sizes, ooms],
  192. ))
  193. logging_outputs = list(chain.from_iterable(logging_outputs))
  194. sample_sizes = list(chain.from_iterable(sample_sizes))
  195. ooms = sum(ooms)
  196. self.meters['oom'].update(ooms, len(samples))
  197. if ooms == self.args.distributed_world_size * len(samples):
  198. print('| WARNING: OOM in all workers, skipping update')
  199. self.zero_grad()
  200. return None
  201. # aggregate logging outputs and sample sizes
  202. logging_output = self.task.aggregate_logging_outputs(
  203. logging_outputs, self.criterion
  204. )
  205. sample_size = self.task.grad_denom(sample_sizes, self.criterion)
  206. if not all(k in logging_output for k in ['ntokens', 'nsentences']):
  207. raise Exception((
  208. 'Please update the {}.aggregate_logging_outputs() method to '
  209. 'return ntokens and nsentences'
  210. ).format(self.task.__class__.__name__))
  211. try:
  212. # normalize grads by sample size
  213. self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))
  214. # clip grads
  215. grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
  216. # take an optimization step
  217. self.optimizer.step()
  218. self._num_updates += 1
  219. # update learning rate
  220. self.lr_scheduler.step_update(self._num_updates)
  221. # update meters
  222. ntokens = logging_output.get('ntokens', 0)
  223. nsentences = logging_output.get('nsentences', 0)
  224. self.meters['wps'].update(ntokens)
  225. self.meters['ups'].update(1.)
  226. self.meters['wpb'].update(ntokens)
  227. self.meters['bsz'].update(nsentences)
  228. self.meters['gnorm'].update(grad_norm)
  229. self.meters['clip'].update(
  230. 1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
  231. )
  232. self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
  233. self.meters['train_losses'].update(logging_output.get('loss', 0))
  234. if 'nll_loss' in logging_output:
  235. self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
  236. except OverflowError as e:
  237. print('| WARNING: overflow detected, ' + str(e))
  238. self.zero_grad()
  239. logging_output = None
  240. if self.args.fp16:
  241. self.meters['loss_scale'].reset()
  242. self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)
  243. self.meters['train_wall'].stop()
  244. return logging_output
  245. def valid_step(self, sample, raise_oom=False):
  246. """Do forward pass in evaluation mode."""
  247. with torch.no_grad():
  248. self.model.eval()
  249. self.criterion.eval()
  250. sample = self._prepare_sample(sample)
  251. if sample is None:
  252. sample = self._prepare_sample(self._dummy_batch)
  253. ignore_results = True
  254. else:
  255. ignore_results = False
  256. try:
  257. _loss, sample_size, logging_output = self.task.valid_step(
  258. sample, self.model, self.criterion
  259. )
  260. except RuntimeError as e:
  261. if 'out of memory' in str(e) and not raise_oom:
  262. print('| WARNING: ran out of memory, retrying batch')
  263. for p in self.model.parameters():
  264. if p.grad is not None:
  265. del p.grad # free some memory
  266. if self.cuda:
  267. torch.cuda.empty_cache()
  268. return self.valid_step(sample, raise_oom=True)
  269. else:
  270. raise e
  271. if ignore_results:
  272. logging_output, sample_size = {}, 0
  273. # gather logging outputs from all replicas
  274. if self.args.distributed_world_size > 1:
  275. logging_output, sample_size = zip(*distributed_utils.all_gather_list(
  276. [logging_output, sample_size],
  277. ))
  278. logging_output = list(logging_output)
  279. sample_size = list(sample_size)
  280. else:
  281. logging_output = [logging_output]
  282. sample_size = [sample_size]
  283. # aggregate logging outputs and sample sizes
  284. logging_output = self.task.aggregate_logging_outputs(
  285. logging_output, self.criterion
  286. )
  287. sample_size = self.task.grad_denom(
  288. sample_size, self.criterion
  289. )
  290. # update meters for validation
  291. ntokens = logging_output.get('ntokens', 0)
  292. self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
  293. if 'nll_loss' in logging_output:
  294. self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
  295. return logging_output
  296. def dummy_train_step(self, dummy_batch):
  297. """Dummy training step for warming caching allocator."""
  298. self.train_step(dummy_batch, dummy_batch=True)
  299. self.zero_grad()
  300. def handle_ooms(self, number_of_ooms):
  301. """
  302. c10d accumulates/syncs gradients between gpus during backward pass.
  303. In case of OOMs, gpus may fail to sync, so we manually iterate
  304. extra to make sure each gpu makes same number of iterations.
  305. """
  306. for _ in range(number_of_ooms):
  307. self.train_step([self._oom_batch], True)
  308. def zero_grad(self):
  309. self.optimizer.zero_grad()
  310. def lr_step(self, epoch, val_loss=None):
  311. """Adjust the learning rate based on the validation loss."""
  312. return self.lr_scheduler.step(epoch, val_loss)
  313. def lr_step_update(self, num_updates):
  314. """Update the learning rate after each update."""
  315. return self.lr_scheduler.step_update(num_updates)
  316. def get_lr(self):
  317. """Get the current learning rate."""
  318. return self.optimizer.get_lr()
  319. def get_model(self):
  320. """Get the (non-wrapped) model instance."""
  321. return self._model
  322. def get_meter(self, name):
  323. """Get a specific meter by name."""
  324. if name not in self.meters:
  325. return None
  326. return self.meters[name]
  327. def get_num_updates(self):
  328. """Get the number of parameters updates."""
  329. return self._num_updates
  330. def _prepare_sample(self, sample):
  331. if sample is None or len(sample) == 0:
  332. return None
  333. if self.cuda:
  334. sample = utils.move_to_cuda(sample)
  335. return sample
  336. def _set_seed(self):
  337. # Set seed based on args.seed and the update number so that we get
  338. # reproducible results when resuming from checkpoints
  339. seed = self.args.seed + self.get_num_updates()
  340. torch.manual_seed(seed)
  341. if self.cuda:
  342. torch.cuda.manual_seed(seed)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...