Overview
========

Fairseq can be extended through user-supplied `plug-ins
<https://en.wikipedia.org/wiki/Plug-in_(computing)>`_. We support five kinds of
plug-ins:

- :ref:`Models` define the neural network architecture and encapsulate all of the
  learnable parameters.
- :ref:`Criterions` compute the loss function given the model outputs and targets.
- :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
  Datasets, initializing the Model/Criterion and calculating the loss.
- :ref:`Optimizers` update the Model parameters based on the gradients.
- :ref:`Learning Rate Schedulers` update the learning rate over the course of
  training.
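
For orientation, the following minimal sketch shows how these five plug-in
types are typically built from parsed command-line ``args``, based on
fairseq's builder helpers (exact module paths and signatures may differ
between fairseq versions)::

  from fairseq import optim, tasks

  task = tasks.setup_task(args)           # Task: dictionaries, datasets
  model = task.build_model(args)          # Model: network architecture
  criterion = task.build_criterion(args)  # Criterion: loss function
  optimizer = optim.build_optimizer(args, model.parameters())
  lr_scheduler = optim.lr_scheduler.build_lr_scheduler(args, optimizer)
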
**Training Flow**

Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
fairseq implements the following high-level training flow::

  for epoch in range(num_epochs):
      itr = task.get_batch_iterator(task.dataset('train'))
      for num_updates, batch in enumerate(itr):
          task.train_step(batch, model, criterion, optimizer)
          average_and_clip_gradients()
          optimizer.step()
          lr_scheduler.step_update(num_updates)
      lr_scheduler.step(epoch)

where the default implementation for ``task.train_step`` is roughly::

  def train_step(self, batch, model, criterion, optimizer):
      loss = criterion(model, batch)
      optimizer.backward(loss)
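
A custom Task can override ``train_step`` to change this behavior. A hedged
sketch, mirroring the simplified signature above (the real method takes
additional arguments and returns logging outputs, and ``MyTask`` is
hypothetical)::

  from fairseq.tasks.translation import TranslationTask

  class MyTask(TranslationTask):
      def train_step(self, batch, model, criterion, optimizer):
          # Example customization: scale the loss before the backward pass.
          loss = criterion(model, batch)
          optimizer.backward(loss * 0.5)
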
**Registering new plug-ins**

New plug-ins are *registered* through a set of ``@register`` function
decorators, for example::

  @register_model('my_lstm')
  class MyLSTM(FairseqModel):
      (...)

Once registered, new plug-ins can be used with the existing :ref:`Command-line
Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
new plug-ins.
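
Each plug-in kind has its own decorator and base class. As a hedged sketch
(``my_criterion`` and the loss computation are hypothetical; the
``(loss, sample_size, logging_output)`` return convention follows fairseq but
may vary slightly between versions), a custom criterion could be registered
as::

  from fairseq.criterions import FairseqCriterion, register_criterion

  @register_criterion('my_criterion')
  class MyCriterion(FairseqCriterion):
      def forward(self, model, sample, reduce=True):
          net_output = model(**sample['net_input'])
          lprobs = model.get_normalized_probs(net_output, log_probs=True)
          # Negative log-likelihood of the target tokens.
          loss = -lprobs.gather(-1, sample['target'].unsqueeze(-1)).sum()
          sample_size = sample['target'].numel()
          logging_output = {'loss': loss.item(), 'sample_size': sample_size}
          return loss, sample_size, logging_output
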
**Loading plug-ins from another directory**

New plug-ins can also be defined in a custom module stored on the user's
system. To import such a module and make its plug-ins available to *fairseq*,
the command-line tools support a ``--user-dir`` flag that specifies a custom
location for additional modules to load into *fairseq*.

For example, assuming this directory tree::

  /home/user/my-module/
  └── __init__.py

with ``__init__.py``::

  from fairseq.models import register_model_architecture
  from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big

  @register_model_architecture('transformer', 'my_transformer')
  def transformer_mmt_big(args):
      transformer_vaswani_wmt_en_de_big(args)

it is possible to invoke the ``train.py`` script with the new architecture::

  python3 train.py ... --user-dir /home/user/my-module -a my_transformer --task translation
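
Under the hood, ``--user-dir`` simply imports the module, which runs its
``@register_*`` decorators at import time. Note that ``my_transformer`` above
does not define a new model class; it registers a named architecture that
applies the ``transformer_vaswani_wmt_en_de_big`` hyperparameters to the
existing ``transformer`` model.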