Tutorial: Classifying Names with a Character-Level RNN
=======================================================

In this tutorial we will extend fairseq to support *classification* tasks. In
particular we will re-implement the PyTorch tutorial for `Classifying Names with
a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
in fairseq. It is recommended to quickly skim that tutorial before beginning
this one.

This tutorial covers:

1. **Preprocessing the data** to create dictionaries.
2. **Registering a new Model** that encodes an input sentence with a simple RNN
   and predicts the output label.
3. **Registering a new Task** that loads our dictionaries and dataset.
4. **Training the Model** using the existing command-line tools.
5. **Writing an evaluation script** that imports fairseq and allows us to
   interactively evaluate our model on new inputs.

1. Preprocessing the data
-------------------------

The original tutorial provides raw data, but we'll work with a modified version
of the data that is already tokenized into characters and split into separate
train, valid and test sets.
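
Each line of an ``*.input`` file holds one name tokenized into space-separated
characters, and the corresponding line of the matching ``*.label`` file holds
its language label. For example (illustrative only; the actual names and their
order in the files will differ):

.. code-block:: console

    > head -n 1 names/train.input
    S a t o s h i
    > head -n 1 names/train.label
    Japanese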

Download and extract the data from here:
`tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_

Once extracted, let's preprocess the data using the :ref:`preprocess.py`
command-line tool to create the dictionaries. While this tool is primarily
intended for sequence-to-sequence problems, we're able to reuse it here by
treating the label as a "target" sequence of length 1. We'll also output the
preprocessed files in "raw" format using the ``--output-format`` option to
enhance readability:

.. code-block:: console

    > python preprocess.py \
      --trainpref names/train --validpref names/valid --testpref names/test \
      --source-lang input --target-lang label \
      --destdir names-bin --output-format raw

After running the above command you should see a new directory,
:file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
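
If you'd like a quick sanity check (optional; not part of the original tutorial
steps), the dictionaries written by :ref:`preprocess.py` can be loaded with
:class:`~fairseq.data.Dictionary`, the same class our Task will use in the
sections below::

    from fairseq.data import Dictionary

    # Load the dictionaries that preprocess.py just wrote to names-bin/.
    input_vocab = Dictionary.load('names-bin/dict.input.txt')
    label_vocab = Dictionary.load('names-bin/dict.label.txt')
    print('{} input (character) types'.format(len(input_vocab)))
    print('{} label types'.format(len(label_vocab)))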

2. Registering a new Model
--------------------------

Next we'll register a new model in fairseq that will encode an input sentence
with a simple RNN and predict the output label. Compared to the original PyTorch
tutorial, our version will also work with batches of data and GPU Tensors.

First let's copy the simple RNN module implemented in the `PyTorch tutorial
<https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
following contents::

    import torch
    import torch.nn as nn


    class RNN(nn.Module):

        def __init__(self, input_size, hidden_size, output_size):
            super(RNN, self).__init__()

            self.hidden_size = hidden_size

            self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
            self.i2o = nn.Linear(input_size + hidden_size, output_size)
            self.softmax = nn.LogSoftmax(dim=1)

        def forward(self, input, hidden):
            combined = torch.cat((input, hidden), 1)
            hidden = self.i2h(combined)
            output = self.i2o(combined)
            output = self.softmax(output)
            return output, hidden

        def initHidden(self):
            return torch.zeros(1, self.hidden_size)

We must also *register* this model with fairseq using the
:func:`~fairseq.models.register_model` function decorator. Once the model is
registered we'll be able to use it with the existing :ref:`Command-line Tools`.

All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
interface, so we'll create a small wrapper class in the same file and register
it in fairseq with the name ``'rnn_classifier'``::

    from fairseq.models import BaseFairseqModel, register_model

    # Note: the register_model "decorator" should immediately precede the
    # definition of the Model class.
    @register_model('rnn_classifier')
    class FairseqRNNClassifier(BaseFairseqModel):

        @staticmethod
        def add_args(parser):
            # Models can override this method to add new command-line arguments.
            # Here we'll add a new command-line argument to configure the
            # dimensionality of the hidden state.
            parser.add_argument(
                '--hidden-dim', type=int, metavar='N',
                help='dimensionality of the hidden state',
            )

        @classmethod
        def build_model(cls, args, task):
            # Fairseq initializes models by calling the ``build_model()``
            # function. This provides more flexibility, since the returned model
            # instance can be of a different type than the one that was called.
            # In this case we'll just return a FairseqRNNClassifier instance.

            # Initialize our RNN module
            rnn = RNN(
                # We'll define the Task in the next section, but for now just
                # notice that the task holds the dictionaries for the "source"
                # (i.e., the input sentence) and "target" (i.e., the label).
                input_size=len(task.source_dictionary),
                hidden_size=args.hidden_dim,
                output_size=len(task.target_dictionary),
            )

            # Return the wrapped version of the module
            return FairseqRNNClassifier(
                rnn=rnn,
                input_vocab=task.source_dictionary,
            )

        def __init__(self, rnn, input_vocab):
            super(FairseqRNNClassifier, self).__init__()

            self.rnn = rnn
            self.input_vocab = input_vocab

            # The RNN module in the tutorial expects one-hot inputs, so we can
            # precompute the identity matrix to help convert from indices to
            # one-hot vectors. We register it as a buffer so that it is moved to
            # the GPU when ``cuda()`` is called.
            self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We'll define the Task in the next section, but for
            # now just know that *src_tokens* has shape `(batch, src_len)` and
            # *src_lengths* has shape `(batch)`.
            bsz, max_src_len = src_tokens.size()

            # Initialize the RNN hidden state. Compared to the original PyTorch
            # tutorial we'll also handle batched inputs and work on the GPU.
            hidden = self.rnn.initHidden()
            hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
            hidden = hidden.to(src_tokens.device)  # move to GPU

            for i in range(max_src_len):
                # WARNING: The inputs have padding, so we should mask those
                # elements here so that padding doesn't affect the results.
                # This is left as an exercise for the reader. The padding symbol
                # is given by ``self.input_vocab.pad()`` and the unpadded length
                # of each input is given by *src_lengths*.

                # One-hot encode a batch of input characters.
                input = self.one_hot_inputs[src_tokens[:, i].long()]

                # Feed the input to our RNN.
                output, hidden = self.rnn(input, hidden)

            # Return the final output state for making a prediction
            return output
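
As the ``WARNING`` comment above notes, padding is not yet handled and is left
as an exercise. One possible solution is to keep the previous hidden state and
output wherever the current position is padding. This is only a rough sketch of
one way to do it (assuming ``left_pad_source=False``, so padding is appended on
the right), not part of the original tutorial; it would replace the loop at the
end of ``forward()``::

    # Sketch only: one possible padding-aware version of the loop above.
    pad_idx = self.input_vocab.pad()
    output = None
    for i in range(max_src_len):
        input = self.one_hot_inputs[src_tokens[:, i].long()]
        step_output, step_hidden = self.rnn(input, hidden)

        # 1.0 for real tokens, 0.0 for padding; shape `(batch, 1)`.
        mask = src_tokens[:, i].ne(pad_idx).unsqueeze(1).type_as(step_hidden)

        # Keep the previous state/output for positions that are padding.
        hidden = mask * step_hidden + (1 - mask) * hidden
        if output is None:
            output = step_output
        else:
            output = mask * step_output + (1 - mask) * output
    return output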

Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::

    from fairseq.models import register_model_architecture

    # The first argument to ``register_model_architecture()`` should be the name
    # of the model we registered above (i.e., 'rnn_classifier'). The function we
    # register here should take a single argument *args* and modify it in-place
    # to match the desired architecture.

    @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
    def pytorch_tutorial_rnn(args):
        # We use ``getattr()`` to prioritize arguments that are explicitly given
        # on the command-line, so that the defaults defined below are only used
        # when no other value has been specified.
        args.hidden_dim = getattr(args, 'hidden_dim', 128)
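
The same mechanism can register any number of variants of a model. For example,
a hypothetical ``pytorch_tutorial_rnn_big`` architecture (not part of the
original tutorial) could simply default to a larger hidden state::

    @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn_big')
    def pytorch_tutorial_rnn_big(args):
        # Identical to the architecture above except for the default hidden
        # state size; ``--hidden-dim`` on the command line still takes priority.
        args.hidden_dim = getattr(args, 'hidden_dim', 256)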

3. Registering a new Task
-------------------------

Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
dictionaries and dataset. Tasks can also control how the data is batched into
mini-batches, but in this tutorial we'll reuse the batching provided by
:class:`fairseq.data.LanguagePairDataset`.

Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
following contents::

    import os
    import torch

    from fairseq.data import Dictionary, LanguagePairDataset
    from fairseq.tasks import FairseqTask, register_task
    from fairseq.tokenizer import Tokenizer


    @register_task('simple_classification')
    class SimpleClassificationTask(FairseqTask):

        @staticmethod
        def add_args(parser):
            # Add some command-line arguments for specifying where the data is
            # located and the maximum supported input length.
            parser.add_argument('data', metavar='FILE',
                                help='file prefix for data')
            parser.add_argument('--max-positions', default=1024, type=int,
                                help='max input length')

        @classmethod
        def setup_task(cls, args, **kwargs):
            # Here we can perform any setup required for the task. This may include
            # loading Dictionaries, initializing shared Embedding layers, etc.
            # In this case we'll just load the Dictionaries.
            input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
            label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
            print('| [input] dictionary: {} types'.format(len(input_vocab)))
            print('| [label] dictionary: {} types'.format(len(label_vocab)))

            return SimpleClassificationTask(args, input_vocab, label_vocab)

        def __init__(self, args, input_vocab, label_vocab):
            super().__init__(args)
            self.input_vocab = input_vocab
            self.label_vocab = label_vocab

        def load_dataset(self, split, **kwargs):
            """Load a given dataset split (e.g., train, valid, test)."""

            prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

            # Read input sentences.
            sentences, lengths = [], []
            with open(prefix + '.input', encoding='utf-8') as file:
                for line in file:
                    sentence = line.strip()

                    # Tokenize the sentence, splitting on spaces
                    tokens = Tokenizer.tokenize(
                        sentence, self.input_vocab, add_if_not_exist=False,
                    )

                    sentences.append(tokens)
                    lengths.append(tokens.numel())

            # Read labels.
            labels = []
            with open(prefix + '.label', encoding='utf-8') as file:
                for line in file:
                    label = line.strip()
                    labels.append(
                        # Convert label to a numeric ID.
                        torch.LongTensor([self.label_vocab.add_symbol(label)])
                    )

            assert len(sentences) == len(labels)
            print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

            # We reuse LanguagePairDataset since classification can be modeled as a
            # sequence-to-sequence task where the target sequence has length 1.
            self.datasets[split] = LanguagePairDataset(
                src=sentences,
                src_sizes=lengths,
                src_dict=self.input_vocab,
                tgt=labels,
                tgt_sizes=torch.ones(len(labels)),  # targets have length 1
                tgt_dict=self.label_vocab,
                left_pad_source=False,
                max_source_positions=self.args.max_positions,
                max_target_positions=1,
                # Since our target is a single class label, there's no need for
                # input feeding. If we set this to ``True`` then our Model's
                # ``forward()`` method would receive an additional argument called
                # *prev_output_tokens* that would contain a shifted version of the
                # target sequence.
                input_feeding=False,
            )

        def max_positions(self):
            """Return the max input length allowed by the task."""
            # The source should be less than *args.max_positions* and the "target"
            # has max length 1.
            return (self.args.max_positions, 1)

        @property
        def source_dictionary(self):
            """Return the source :class:`~fairseq.data.Dictionary`."""
            return self.input_vocab

        @property
        def target_dictionary(self):
            """Return the target :class:`~fairseq.data.Dictionary`."""
            return self.label_vocab

        # We could override this method if we wanted more control over how batches
        # are constructed, but it's not necessary for this tutorial since we can
        # reuse the batching provided by LanguagePairDataset.
        #
        # def get_batch_iterator(
        #     self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        #     ignore_invalid_inputs=False, required_batch_size_multiple=1,
        #     seed=1, num_shards=1, shard_id=0,
        # ):
        #     (...)

4. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`train.py`
command-line tool for this, making sure to specify our new Task (``--task
simple_classification``) and Model architecture (``--arch
pytorch_tutorial_rnn``):

.. note::

    You can also configure the dimensionality of the hidden state by passing the
    ``--hidden-dim`` argument to :ref:`train.py`.

.. code-block:: console

    > python train.py names-bin \
      --task simple_classification \
      --arch pytorch_tutorial_rnn \
      --optimizer adam --lr 0.001 --lr-shrink 0.5 \
      --max-tokens 1000
    (...)
    | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
    | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
    | done training in 31.6 seconds
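
As the note above mentions, the hidden state size can also be set from the
command line, for example (illustrative; results will vary):

.. code-block:: console

    > python train.py names-bin \
      --task simple_classification \
      --arch pytorch_tutorial_rnn \
      --optimizer adam --lr 0.001 --lr-shrink 0.5 \
      --max-tokens 1000 \
      --hidden-dim 256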

The model files should appear in the :file:`checkpoints/` directory.

5. Writing an evaluation script
-------------------------------

Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::

    from fairseq import data, options, tasks, utils
    from fairseq.tokenizer import Tokenizer

    # Parse command-line arguments for generation
    parser = options.get_generation_parser(default_task='simple_classification')
    args = options.parse_args_and_arch(parser)

    # Setup task
    task = tasks.setup_task(args)

    # Load model
    print('| loading model from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference([args.path], task)
    model = models[0]

    while True:
        sentence = input('\nInput: ')

        # Tokenize into characters
        chars = ' '.join(list(sentence.strip()))
        tokens = Tokenizer.tokenize(
            chars, task.source_dictionary, add_if_not_exist=False,
        )

        # Build mini-batch to feed to the model
        batch = data.language_pair_dataset.collate(
            samples=[{'id': -1, 'source': tokens}],  # bsz = 1
            pad_idx=task.source_dictionary.pad(),
            eos_idx=task.source_dictionary.eos(),
            left_pad_source=False,
            input_feeding=False,
        )

        # Feed batch to the model and get predictions
        preds = model(**batch['net_input'])

        # Print top 3 predictions and their log-probabilities
        top_scores, top_labels = preds[0].topk(k=3)
        for score, label_idx in zip(top_scores, top_labels):
            label_name = task.target_dictionary.string([label_idx])
            print('({:.2f})\t{}'.format(score, label_name))

Now we can evaluate our model interactively. Note that we have included the
original data path (:file:`names-bin/`) so that the dictionaries can be loaded:

.. code-block:: console

    > python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
    | [input] dictionary: 64 types
    | [label] dictionary: 24 types
    | loading model from checkpoints/checkpoint_best.pt

    Input: Satoshi
    (-0.61) Japanese
    (-1.20) Arabic
    (-2.86) Italian

    Input: Sinbad
    (-0.30) Arabic
    (-1.76) English
    (-4.08) Russian