Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 13 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
  1. #!/usr/bin/env python3 -u
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the license found in the LICENSE file in
  6. # the root directory of this source tree. An additional grant of patent rights
  7. # can be found in the PATENTS file in the same directory.
  8. import collections
  9. import itertools
  10. import os
  11. import math
  12. import torch
  13. from fairseq import data, distributed_utils, options, progress_bar, tasks, utils
  14. from fairseq.fp16_trainer import FP16Trainer
  15. from fairseq.trainer import Trainer
  16. from fairseq.meters import AverageMeter, StopwatchMeter
  17. def main(args):
  18. if args.max_tokens is None:
  19. args.max_tokens = 6000
  20. print(args)
  21. if not torch.cuda.is_available():
  22. raise NotImplementedError('Training on CPU is not supported')
  23. torch.cuda.set_device(args.device_id)
  24. torch.manual_seed(args.seed)
  25. # Setup task, e.g., translation, language modeling, etc.
  26. task = tasks.setup_task(args)
  27. # Load dataset splits
  28. load_dataset_splits(args, task, ['train', 'valid'])
  29. # Build model and criterion
  30. model = task.build_model(args)
  31. criterion = task.build_criterion(args)
  32. print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
  33. print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
  34. # Build trainer
  35. if args.fp16:
  36. trainer = FP16Trainer(args, task, model, criterion)
  37. else:
  38. if torch.cuda.get_device_capability(0)[0] >= 7:
  39. print('| NOTICE: your device may support faster training with --fp16')
  40. trainer = Trainer(args, task, model, criterion)
  41. print('| training on {} GPUs'.format(args.distributed_world_size))
  42. print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
  43. args.max_tokens,
  44. args.max_sentences,
  45. ))
  46. # Initialize dataloader
  47. max_positions = trainer.get_model().max_positions()
  48. epoch_itr = data.EpochBatchIterator(
  49. dataset=task.dataset(args.train_subset),
  50. max_tokens=args.max_tokens,
  51. max_sentences=args.max_sentences_valid,
  52. max_positions=max_positions,
  53. ignore_invalid_inputs=True,
  54. required_batch_size_multiple=8,
  55. seed=args.seed,
  56. num_shards=args.distributed_world_size,
  57. shard_id=args.distributed_rank,
  58. )
  59. # Load the latest checkpoint if one is available
  60. load_checkpoint(args, trainer, epoch_itr)
  61. # Send a dummy batch to warm the caching allocator
  62. dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
  63. trainer.dummy_train_step(dummy_batch)
  64. # Train until the learning rate gets too small
  65. max_epoch = args.max_epoch or math.inf
  66. max_update = args.max_update or math.inf
  67. lr = trainer.get_lr()
  68. train_meter = StopwatchMeter()
  69. train_meter.start()
  70. valid_losses = [None]
  71. valid_subsets = args.valid_subset.split(',')
  72. while lr > args.min_lr and epoch_itr.epoch <= max_epoch and trainer.get_num_updates() < max_update:
  73. # train for one epoch
  74. train(args, trainer, task, epoch_itr)
  75. if epoch_itr.epoch % args.validate_interval == 0:
  76. valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
  77. # only use first validation loss to update the learning rate
  78. lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
  79. # save checkpoint
  80. if epoch_itr.epoch % args.save_interval == 0:
  81. save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
  82. train_meter.stop()
  83. print('| done training in {:.1f} seconds'.format(train_meter.sum))
  84. def train(args, trainer, task, epoch_itr):
  85. """Train the model for one epoch."""
  86. # Initialize data iterator
  87. itr = epoch_itr.next_epoch_itr()
  88. progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')
  89. # update parameters every N batches
  90. if epoch_itr.epoch <= len(args.update_freq):
  91. update_freq = args.update_freq[epoch_itr.epoch - 1]
  92. else:
  93. update_freq = args.update_freq[-1]
  94. extra_meters = collections.defaultdict(lambda: AverageMeter())
  95. first_valid = args.valid_subset.split(',')[0]
  96. max_update = args.max_update or math.inf
  97. num_batches = len(epoch_itr)
  98. for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
  99. if i < num_batches - 1 and (i + 1) % update_freq > 0:
  100. # buffer updates according to --update-freq
  101. trainer.train_step(sample, update_params=False)
  102. continue
  103. else:
  104. log_output = trainer.train_step(sample, update_params=True)
  105. # log mid-epoch stats
  106. stats = get_training_stats(trainer)
  107. for k, v in log_output.items():
  108. if k in ['loss', 'nll_loss', 'sample_size']:
  109. continue # these are already logged above
  110. if 'loss' in k:
  111. extra_meters[k].update(v, log_output['sample_size'])
  112. else:
  113. extra_meters[k].update(v)
  114. stats[k] = extra_meters[k].avg
  115. progress.log(stats)
  116. # ignore the first mini-batch in words-per-second calculation
  117. if i == 0:
  118. trainer.get_meter('wps').reset()
  119. num_updates = trainer.get_num_updates()
  120. if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
  121. valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
  122. save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
  123. if num_updates >= max_update:
  124. break
  125. # log end-of-epoch stats
  126. stats = get_training_stats(trainer)
  127. for k, meter in extra_meters.items():
  128. stats[k] = meter.avg
  129. progress.print(stats)
  130. # reset training meters
  131. for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
  132. meter = trainer.get_meter(k)
  133. if meter is not None:
  134. meter.reset()
  135. def get_training_stats(trainer):
  136. stats = collections.OrderedDict()
  137. stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg)
  138. if trainer.get_meter('train_nll_loss').count > 0:
  139. nll_loss = trainer.get_meter('train_nll_loss').avg
  140. stats['nll_loss'] = '{:.3f}'.format(nll_loss)
  141. else:
  142. nll_loss = trainer.get_meter('train_loss').avg
  143. stats['ppl'] = get_perplexity(nll_loss)
  144. stats['wps'] = round(trainer.get_meter('wps').avg)
  145. stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg)
  146. stats['wpb'] = round(trainer.get_meter('wpb').avg)
  147. stats['bsz'] = round(trainer.get_meter('bsz').avg)
  148. stats['num_updates'] = trainer.get_num_updates()
  149. stats['lr'] = trainer.get_lr()
  150. stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
  151. stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
  152. stats['oom'] = trainer.get_meter('oom').avg
  153. if trainer.get_meter('loss_scale') is not None:
  154. stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
  155. stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
  156. return stats
  157. def validate(args, trainer, task, epoch_itr, subsets):
  158. """Evaluate the model on the validation set(s) and return the losses."""
  159. valid_losses = []
  160. for subset in subsets:
  161. # Initialize data iterator
  162. itr = data.EpochBatchIterator(
  163. dataset=task.dataset(subset),
  164. max_tokens=args.max_tokens,
  165. max_sentences=args.max_sentences_valid,
  166. max_positions=trainer.get_model().max_positions(),
  167. ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
  168. required_batch_size_multiple=8,
  169. seed=args.seed,
  170. num_shards=args.distributed_world_size,
  171. shard_id=args.distributed_rank,
  172. ).next_epoch_itr(shuffle=False)
  173. progress = progress_bar.build_progress_bar(
  174. args, itr, epoch_itr.epoch,
  175. prefix='valid on \'{}\' subset'.format(subset),
  176. no_progress_bar='simple'
  177. )
  178. # reset validation loss meters
  179. for k in ['valid_loss', 'valid_nll_loss']:
  180. meter = trainer.get_meter(k)
  181. if meter is not None:
  182. meter.reset()
  183. extra_meters = collections.defaultdict(lambda: AverageMeter())
  184. for sample in progress:
  185. log_output = trainer.valid_step(sample)
  186. for k, v in log_output.items():
  187. if k in ['loss', 'nll_loss', 'sample_size']:
  188. continue
  189. extra_meters[k].update(v)
  190. # log validation stats
  191. stats = get_valid_stats(trainer)
  192. for k, meter in extra_meters.items():
  193. stats[k] = meter.avg
  194. progress.print(stats)
  195. valid_losses.append(stats['valid_loss'])
  196. return valid_losses
  197. def get_valid_stats(trainer):
  198. stats = collections.OrderedDict()
  199. stats['valid_loss'] = trainer.get_meter('valid_loss').avg
  200. if trainer.get_meter('valid_nll_loss').count > 0:
  201. nll_loss = trainer.get_meter('valid_nll_loss').avg
  202. stats['valid_nll_loss'] = nll_loss
  203. else:
  204. nll_loss = trainer.get_meter('valid_loss').avg
  205. stats['valid_ppl'] = get_perplexity(nll_loss)
  206. stats['num_updates'] = trainer.get_num_updates()
  207. if hasattr(save_checkpoint, 'best'):
  208. stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
  209. return stats
  210. def get_perplexity(loss):
  211. try:
  212. return '{:.2f}'.format(math.pow(2, loss))
  213. except OverflowError:
  214. return float('inf')
  215. def save_checkpoint(args, trainer, epoch_itr, val_loss):
  216. if args.no_save or not distributed_utils.is_master(args):
  217. return
  218. epoch = epoch_itr.epoch
  219. end_of_epoch = epoch_itr.end_of_epoch()
  220. updates = trainer.get_num_updates()
  221. checkpoint_conds = collections.OrderedDict()
  222. checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
  223. end_of_epoch and not args.no_epoch_checkpoints and
  224. epoch % args.save_interval == 0
  225. )
  226. checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
  227. not end_of_epoch and args.save_interval_updates > 0 and
  228. updates % args.save_interval_updates == 0
  229. )
  230. checkpoint_conds['checkpoint_best.pt'] = (
  231. val_loss is not None and
  232. (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
  233. )
  234. checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink
  235. prev_best = getattr(save_checkpoint, 'best', val_loss)
  236. if val_loss is not None:
  237. save_checkpoint.best = min(val_loss, prev_best)
  238. extra_state = {
  239. 'best': save_checkpoint.best,
  240. 'train_iterator': epoch_itr.state_dict(),
  241. 'val_loss': val_loss,
  242. }
  243. checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
  244. if len(checkpoints) > 0:
  245. for fn in checkpoints:
  246. if os.path.exists(fn):
  247. os.remove(fn)
  248. if not end_of_epoch and args.keep_interval_updates > 0:
  249. for cp in checkpoints:
  250. trainer.save_checkpoint(cp, extra_state)
  251. else:
  252. trainer.save_checkpoint(checkpoints[0], extra_state)
  253. for fn in checkpoints[1:]:
  254. os.symlink(os.path.basename(checkpoints[0]), fn)
  255. if not end_of_epoch and args.keep_interval_updates > 0:
  256. # remove old checkpoints; checkpoints are sorted in descending order
  257. checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
  258. for old_chk in checkpoints[args.keep_interval_updates:]:
  259. os.remove(old_chk)
  260. def load_checkpoint(args, trainer, epoch_itr):
  261. """Load a checkpoint and replay dataloader to match."""
  262. os.makedirs(args.save_dir, exist_ok=True)
  263. checkpoint_path = os.path.join(args.save_dir, args.restore_file)
  264. if os.path.isfile(checkpoint_path):
  265. extra_state = trainer.load_checkpoint(checkpoint_path)
  266. if extra_state is not None:
  267. # replay train iterator to match checkpoint
  268. epoch_itr.load_state_dict(extra_state['train_iterator'])
  269. print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
  270. checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
  271. trainer.lr_step(epoch_itr.epoch)
  272. trainer.lr_step_update(trainer.get_num_updates())
  273. if 'best' in extra_state:
  274. save_checkpoint.best = extra_state['best']
  275. def load_dataset_splits(args, task, splits):
  276. for split in splits:
  277. for k in itertools.count():
  278. split_k = split + (str(k) if k > 0 else '')
  279. try:
  280. task.load_dataset(split_k)
  281. print('| {} {} {} examples'.format(args.data, split_k, len(task.dataset(split_k))))
  282. except FileNotFoundError as e:
  283. if k > 0:
  284. break
  285. raise e
  286. if __name__ == '__main__':
  287. parser = options.get_training_parser()
  288. args = options.parse_args_and_arch(parser)
  289. if args.distributed_port > 0 or args.distributed_init_method is not None:
  290. from distributed_train import main as distributed_main
  291. distributed_main(args)
  292. elif args.distributed_world_size > 1:
  293. from multiprocessing_train import main as multiprocessing_main
  294. multiprocessing_main(args)
  295. else:
  296. main(args)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...