Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

__init__.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import argparse
  8. import importlib
  9. import os
  10. from .fairseq_task import FairseqTask
  11. TASK_REGISTRY = {}
  12. TASK_CLASS_NAMES = set()
  13. def setup_task(args):
  14. return TASK_REGISTRY[args.task].setup_task(args)
  15. def register_task(name):
  16. """
  17. New tasks can be added to fairseq with the
  18. :func:`~fairseq.tasks.register_task` function decorator.
  19. For example::
  20. @register_task('classification')
  21. class ClassificationTask(FairseqTask):
  22. (...)
  23. .. note::
  24. All Tasks must implement the :class:`~fairseq.tasks.FairseqTask`
  25. interface.
  26. Please see the
  27. Args:
  28. name (str): the name of the task
  29. """
  30. def register_task_cls(cls):
  31. if name in TASK_REGISTRY:
  32. raise ValueError('Cannot register duplicate task ({})'.format(name))
  33. if not issubclass(cls, FairseqTask):
  34. raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__))
  35. if cls.__name__ in TASK_CLASS_NAMES:
  36. raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__))
  37. TASK_REGISTRY[name] = cls
  38. TASK_CLASS_NAMES.add(cls.__name__)
  39. return cls
  40. return register_task_cls
  41. # automatically import any Python files in the tasks/ directory
  42. for file in os.listdir(os.path.dirname(__file__)):
  43. if file.endswith('.py') and not file.startswith('_'):
  44. task_name = file[:file.find('.py')]
  45. importlib.import_module('fairseq.tasks.' + task_name)
  46. # expose `task_parser` for sphinx
  47. if task_name in TASK_REGISTRY:
  48. parser = argparse.ArgumentParser(add_help=False)
  49. group_task = parser.add_argument_group('Task name')
  50. # fmt: off
  51. group_task.add_argument('--task', metavar=task_name,
  52. help='Enable this task with: ``--task=' + task_name + '``')
  53. # fmt: on
  54. group_args = parser.add_argument_group('Additional command-line arguments')
  55. TASK_REGISTRY[task_name].add_args(group_args)
  56. globals()[task_name + '_parser'] = parser
  57. def get_task(name):
  58. return TASK_REGISTRY[name]
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...