Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

options.py 23 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import argparse
  8. import torch
  9. from fairseq.criterions import CRITERION_REGISTRY
  10. from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
  11. from fairseq.optim import OPTIMIZER_REGISTRY
  12. from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
  13. from fairseq.tasks import TASK_REGISTRY
  14. from fairseq.utils import import_user_module
  15. def get_preprocessing_parser(default_task='translation'):
  16. parser = get_parser('Preprocessing', default_task)
  17. add_preprocess_args(parser)
  18. return parser
  19. def get_training_parser(default_task='translation'):
  20. parser = get_parser('Trainer', default_task)
  21. add_dataset_args(parser, train=True)
  22. add_distributed_training_args(parser)
  23. add_model_args(parser)
  24. add_optimization_args(parser)
  25. add_checkpoint_args(parser)
  26. return parser
  27. def get_generation_parser(interactive=False, default_task='translation'):
  28. parser = get_parser('Generation', default_task)
  29. add_dataset_args(parser, gen=True)
  30. add_generation_args(parser)
  31. if interactive:
  32. add_interactive_args(parser)
  33. return parser
  34. def get_interactive_generation_parser(default_task='translation'):
  35. return get_generation_parser(interactive=True, default_task=default_task)
  36. def get_eval_lm_parser(default_task='language_modeling'):
  37. parser = get_parser('Evaluate Language Model', default_task)
  38. add_dataset_args(parser, gen=True)
  39. add_eval_lm_args(parser)
  40. return parser
  41. def eval_str_list(x, type=float):
  42. if x is None:
  43. return None
  44. if isinstance(x, str):
  45. x = eval(x)
  46. try:
  47. return list(map(type, x))
  48. except TypeError:
  49. return [type(x)]
  50. def eval_bool(x, default=False):
  51. if x is None:
  52. return default
  53. try:
  54. return bool(eval(x))
  55. except TypeError:
  56. return default
def parse_args_and_arch(parser, input_args=None, parse_known=False):
    """Parse command-line arguments with a two-pass strategy.

    The parser doesn't know about model/criterion/optimizer-specific args,
    so we parse twice: the first pass discovers which architecture,
    criterion, optimizer, lr-scheduler and task were selected, then each
    component registers its own arguments and we parse a second time.

    Args:
        parser: parser produced by :func:`get_parser`.
        input_args: optional list of args to parse instead of ``sys.argv``.
        parse_known: if True, unknown args are tolerated and returned.

    Returns:
        the parsed namespace, or ``(namespace, extra)`` when *parse_known*.
    """
    # First pass: only the args the base parser already knows about.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)
    # Add model-specific args to parser.
    if hasattr(args, 'arch'):
        model_specific_group = parser.add_argument_group(
            'Model-specific configuration',
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
    # Add *-specific args to parser.
    if hasattr(args, 'criterion'):
        CRITERION_REGISTRY[args.criterion].add_args(parser)
    if hasattr(args, 'optimizer'):
        OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
    if hasattr(args, 'lr_scheduler'):
        LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
    if hasattr(args, 'task'):
        TASK_REGISTRY[args.task].add_args(parser)
    # Parse a second time, now that component-specific args are registered.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None
    # Post-process args: --lr and --update-freq are stored as strings so
    # they can encode per-epoch schedules (e.g. "--lr 0.1,0.01").
    if hasattr(args, 'lr'):
        args.lr = eval_str_list(args.lr, type=float)
    if hasattr(args, 'update_freq'):
        args.update_freq = eval_str_list(args.update_freq, type=int)
    if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
        # Fall back to the training batch size for validation.
        args.max_sentences_valid = args.max_sentences
    if getattr(args, 'memory_efficient_fp16', False):
        # --memory-efficient-fp16 implies --fp16 (see get_parser help text).
        args.fp16 = True
    # Apply architecture configuration (fills in architecture defaults).
    if hasattr(args, 'arch'):
        ARCH_CONFIG_REGISTRY[args.arch](args)
    if parse_known:
        return args, extra
    else:
        return args
def get_parser(desc, default_task='translation'):
    """Create the base parser with options common to all fairseq tools.

    Args:
        desc: short description of the tool being configured.
            NOTE(review): not referenced in the body below — presumably kept
            for call-site readability; confirm before relying on it.
        default_task: task name preselected for ``--task``.

    Returns:
        a fresh ``argparse.ArgumentParser`` with the common options added.
    """
    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False)
    usr_parser.add_argument('--user-dir', default=None)
    usr_args, _ = usr_parser.parse_known_args()
    import_user_module(usr_args)
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
    parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                        help='log progress every N batches (when progress bar is disabled)')
    parser.add_argument('--log-format', default=None, help='log format to use',
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
    parser.add_argument('--fp16', action='store_true', help='use FP16')
    parser.add_argument('--memory-efficient-fp16', action='store_true',
                        help='use a memory-efficient version of FP16 training; implies --fp16')
    parser.add_argument('--fp16-init-scale', default=2 ** 7, type=int,
                        help='default FP16 loss scale')
    parser.add_argument('--fp16-scale-window', type=int,
                        help='number of updates before increasing loss scale')
    parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
                        help='pct of updates that can overflow before decreasing the loss scale')
    # --user-dir is declared again (it was already consumed by usr_parser
    # above) so that it appears in --help output.
    parser.add_argument('--user-dir', default=None,
                        help='path to a python module containing custom extensions (tasks and/or architectures)')
    # Task definitions can be found under fairseq/tasks/
    parser.add_argument('--task', metavar='TASK', default=default_task,
                        choices=TASK_REGISTRY.keys(),
                        help='task')
    # fmt: on
    return parser
  137. def add_preprocess_args(parser):
  138. group = parser.add_argument_group('Preprocessing')
  139. # fmt: off
  140. group.add_argument("-s", "--source-lang", default=None, metavar="SRC",
  141. help="source language")
  142. group.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
  143. help="target language")
  144. group.add_argument("--trainpref", metavar="FP", default=None,
  145. help="train file prefix")
  146. group.add_argument("--validpref", metavar="FP", default=None,
  147. help="comma separated, valid file prefixes")
  148. group.add_argument("--testpref", metavar="FP", default=None,
  149. help="comma separated, test file prefixes")
  150. group.add_argument("--destdir", metavar="DIR", default="data-bin",
  151. help="destination dir")
  152. group.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
  153. help="map words appearing less than threshold times to unknown")
  154. group.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
  155. help="map words appearing less than threshold times to unknown")
  156. group.add_argument("--tgtdict", metavar="FP",
  157. help="reuse given target dictionary")
  158. group.add_argument("--srcdict", metavar="FP",
  159. help="reuse given source dictionary")
  160. group.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
  161. help="number of target words to retain")
  162. group.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
  163. help="number of source words to retain")
  164. group.add_argument("--alignfile", metavar="ALIGN", default=None,
  165. help="an alignment file (optional)")
  166. group.add_argument("--output-format", metavar="FORMAT", default="binary",
  167. choices=["binary", "raw"],
  168. help="output format (optional)")
  169. group.add_argument("--joined-dictionary", action="store_true",
  170. help="Generate joined dictionary")
  171. group.add_argument("--only-source", action="store_true",
  172. help="Only process the source language")
  173. group.add_argument("--padding-factor", metavar="N", default=8, type=int,
  174. help="Pad dictionary size to be multiple of N")
  175. group.add_argument("--workers", metavar="N", default=1, type=int,
  176. help="number of parallel workers")
  177. # fmt: on
  178. return parser
  179. def add_dataset_args(parser, train=False, gen=False):
  180. group = parser.add_argument_group('Dataset and data loading')
  181. # fmt: off
  182. group.add_argument('--num-workers', default=0, type=int, metavar='N',
  183. help='how many subprocesses to use for data loading')
  184. group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
  185. help='ignore too long or too short lines in valid and test set')
  186. group.add_argument('--max-tokens', type=int, metavar='N',
  187. help='maximum number of tokens in a batch')
  188. group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
  189. help='maximum number of sentences in a batch')
  190. if train:
  191. group.add_argument('--train-subset', default='train', metavar='SPLIT',
  192. choices=['train', 'valid', 'test'],
  193. help='data subset to use for training (train, valid, test)')
  194. group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
  195. help='comma separated list of data subsets to use for validation'
  196. ' (train, valid, valid1, test, test1)')
  197. group.add_argument('--max-sentences-valid', type=int, metavar='N',
  198. help='maximum number of sentences in a validation batch'
  199. ' (defaults to --max-sentences)')
  200. if gen:
  201. group.add_argument('--gen-subset', default='test', metavar='SPLIT',
  202. help='data subset to generate (train, valid, test)')
  203. group.add_argument('--num-shards', default=1, type=int, metavar='N',
  204. help='shard generation over N shards')
  205. group.add_argument('--shard-id', default=0, type=int, metavar='ID',
  206. help='id of the shard to generate (id < num_shards)')
  207. # fmt: on
  208. return group
  209. def add_distributed_training_args(parser):
  210. group = parser.add_argument_group('Distributed training')
  211. # fmt: off
  212. group.add_argument('--distributed-world-size', type=int, metavar='N',
  213. default=max(1, torch.cuda.device_count()),
  214. help='total number of GPUs across all nodes (default: all visible GPUs)')
  215. group.add_argument('--distributed-rank', default=0, type=int,
  216. help='rank of the current worker')
  217. group.add_argument('--distributed-backend', default='nccl', type=str,
  218. help='distributed backend')
  219. group.add_argument('--distributed-init-method', default=None, type=str,
  220. help='typically tcp://hostname:port that will be used to '
  221. 'establish initial connetion')
  222. group.add_argument('--distributed-port', default=-1, type=int,
  223. help='port number (not required if using --distributed-init-method)')
  224. group.add_argument('--device-id', '--local_rank', default=0, type=int,
  225. help='which GPU to use (usually configured automatically)')
  226. group.add_argument('--ddp-backend', default='c10d', type=str,
  227. choices=['c10d', 'no_c10d'],
  228. help='DistributedDataParallel backend')
  229. group.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB',
  230. help='bucket size for reduction')
  231. group.add_argument('--fix-batches-to-gpus', action='store_true',
  232. help='don\'t shuffle batches between GPUs; this reduces overall '
  233. 'randomness and may affect precision but avoids the cost of '
  234. 're-reading the data')
  235. # fmt: on
  236. return group
def add_optimization_args(parser):
    """Register optimization options (stopping criteria, optimizer choice,
    learning-rate schedule, gradient handling); returns the argument group."""
    group = parser.add_argument_group('Optimization')
    # fmt: off
    group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
                       help='force stop training at specified epoch')
    group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
                       help='force stop training at specified update')
    group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
                       help='clip threshold of gradients')
    group.add_argument('--sentence-avg', action='store_true',
                       help='normalize gradients by the number of sentences in a batch'
                            ' (default is to normalize by number of tokens)')
    # stored as a string so "--update-freq 4,2" can vary per epoch; converted
    # to a list of ints by parse_args_and_arch via eval_str_list
    group.add_argument('--update-freq', default='1', metavar='N',
                       help='update parameters every N_i batches, when in epoch i')
    # Optimizer definitions can be found under fairseq/optim/
    group.add_argument('--optimizer', default='nag', metavar='OPT',
                       choices=OPTIMIZER_REGISTRY.keys(),
                       help='Optimizer')
    # stored as a string so "--lr 0.1,0.01" can encode a per-epoch schedule;
    # converted to a list of floats by parse_args_and_arch via eval_str_list
    group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR_1,LR_2,...,LR_N',
                       help='learning rate for the first N epochs; all epochs >N using LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')
    group.add_argument('--momentum', default=0.99, type=float, metavar='M',
                       help='momentum factor')
    group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                       help='weight decay')
    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
    group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
                       choices=LR_SCHEDULER_REGISTRY.keys(),
                       help='Learning Rate Scheduler')
    group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
    group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                       help='minimum learning rate')
    group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
                       help='minimum loss scale (for FP16 training)')
    # fmt: on
    return group
  274. def add_checkpoint_args(parser):
  275. group = parser.add_argument_group('Checkpointing')
  276. # fmt: off
  277. group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
  278. help='path to save checkpoints')
  279. group.add_argument('--restore-file', default='checkpoint_last.pt',
  280. help='filename in save-dir from which to load checkpoint')
  281. group.add_argument('--reset-optimizer', action='store_true',
  282. help='if set, does not load optimizer state from the checkpoint')
  283. group.add_argument('--reset-lr-scheduler', action='store_true',
  284. help='if set, does not load lr scheduler state from the checkpoint')
  285. group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
  286. help='a dictionary used to override optimizer args when loading a checkpoint')
  287. group.add_argument('--save-interval', type=int, default=1, metavar='N',
  288. help='save a checkpoint every N epochs')
  289. group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
  290. help='save a checkpoint (and validate) every N updates')
  291. group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
  292. help='keep the last N checkpoints saved with --save-interval-updates')
  293. group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N',
  294. help='keep last N epoch checkpoints')
  295. group.add_argument('--no-save', action='store_true',
  296. help='don\'t save models or checkpoints')
  297. group.add_argument('--no-epoch-checkpoints', action='store_true',
  298. help='only store last and best checkpoints')
  299. group.add_argument('--validate-interval', type=int, default=1, metavar='N',
  300. help='validate every N epochs')
  301. # fmt: on
  302. return group
  303. def add_common_eval_args(group):
  304. # fmt: off
  305. group.add_argument('--path', metavar='FILE',
  306. help='path(s) to model file(s), colon separated')
  307. group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
  308. help='remove BPE tokens before scoring (can be set to sentencepiece)')
  309. group.add_argument('--quiet', action='store_true',
  310. help='only print final scores')
  311. group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
  312. help='a dictionary used to override model args at generation '
  313. 'that were used during model training')
  314. # fmt: on
  315. def add_eval_lm_args(parser):
  316. group = parser.add_argument_group('LM Evaluation')
  317. add_common_eval_args(group)
  318. # fmt: off
  319. group.add_argument('--output-word-probs', action='store_true',
  320. help='if set, outputs words and their predicted log probabilities to standard output')
  321. group.add_argument('--output-word-stats', action='store_true',
  322. help='if set, outputs word statistics such as word count, average probability, etc')
  323. # fmt: on
def add_generation_args(parser):
    """Register sequence generation options (beam search, sampling, length
    control, scoring); returns the argument group.  Also pulls in the common
    evaluation options via :func:`add_common_eval_args`."""
    group = parser.add_argument_group('Generation')
    add_common_eval_args(group)
    # fmt: off
    group.add_argument('--beam', default=5, type=int, metavar='N',
                       help='beam size')
    group.add_argument('--nbest', default=1, type=int, metavar='N',
                       help='number of hypotheses to output')
    group.add_argument('--max-len-a', default=0, type=float, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    group.add_argument('--max-len-b', default=200, type=int, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    # NOTE(review): --min-len uses type=float although it describes a length;
    # confirm whether int was intended before changing.
    group.add_argument('--min-len', default=1, type=float, metavar='N',
                       help=('minimum generation length'))
    group.add_argument('--match-source-len', default=False, action='store_true',
                       help=('generations should match the source length'))
    group.add_argument('--no-early-stop', action='store_true',
                       help=('continue searching even after finalizing k=beam '
                             'hypotheses; this is more correct, but increases '
                             'generation time by 50%%'))
    group.add_argument('--unnormalized', action='store_true',
                       help='compare unnormalized hypothesis scores')
    group.add_argument('--no-beamable-mm', action='store_true',
                       help='don\'t use BeamableMM in attention layers')
    group.add_argument('--lenpen', default=1, type=float,
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    group.add_argument('--unkpen', default=0, type=float,
                       help='unknown word penalty: <0 produces more unks, >0 produces fewer')
    # const=True: bare "--replace-unk" enables replacement; an optional value
    # supplies an alignment dictionary path.
    group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                       help='perform unknown replacement (optionally with alignment dictionary)')
    group.add_argument('--sacrebleu', action='store_true',
                       help='score with sacrebleu')
    group.add_argument('--score-reference', action='store_true',
                       help='just score the reference translation')
    group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                       help='initialize generation by target prefix of given length')
    group.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N',
                       help='ngram blocking such that this size ngram cannot be repeated in the generation')
    group.add_argument('--sampling', action='store_true',
                       help='sample hypotheses instead of using beam search')
    group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
                       help='sample from top K likely next words instead of all words')
    group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
                       help='temperature for random sampling')
    group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
                       help='number of groups for Diverse Beam Search')
    group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',
                       help='strength of diversity penalty for Diverse Beam Search')
    group.add_argument('--print-alignment', action='store_true',
                       help='if set, uses attention feedback to compute and print alignment to source tokens')
    # fmt: on
    return group
  378. def add_interactive_args(parser):
  379. group = parser.add_argument_group('Interactive')
  380. # fmt: off
  381. group.add_argument('--buffer-size', default=0, type=int, metavar='N',
  382. help='read this many sentences into a buffer before processing them')
  383. group.add_argument('--input', default='-', type=str, metavar='FILE',
  384. help='file to read from; use - for stdin')
  385. # fmt: on
  386. def add_model_args(parser):
  387. group = parser.add_argument_group('Model configuration')
  388. # fmt: off
  389. # Model definitions can be found under fairseq/models/
  390. #
  391. # The model architecture can be specified in several ways.
  392. # In increasing order of priority:
  393. # 1) model defaults (lowest priority)
  394. # 2) --arch argument
  395. # 3) --encoder/decoder-* arguments (highest priority)
  396. group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True,
  397. choices=ARCH_MODEL_REGISTRY.keys(),
  398. help='Model Architecture')
  399. # Criterion definitions can be found under fairseq/criterions/
  400. group.add_argument('--criterion', default='cross_entropy', metavar='CRIT',
  401. choices=CRITERION_REGISTRY.keys(),
  402. help='Training Criterion')
  403. # fmt: on
  404. return group
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...