Tutorial: Simple LSTM
=====================

In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqModel` that encodes a source sentence with an
LSTM and then passes the final hidden state to a second LSTM that decodes the
target sentence (without attention).

This tutorial covers:

1. **Writing an Encoder and Decoder** to encode/decode the source/target
   sentence, respectively.
2. **Registering a new Model** so that it can be used with the existing
   :ref:`Command-line tools`.
3. **Training the Model** using the existing command-line tools.
4. **Making generation faster** by modifying the Decoder to use
   :ref:`Incremental decoding`.

1. Building an Encoder and Decoder
----------------------------------

In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
Modules.

Encoder
~~~~~~~

Our Encoder will embed the tokens in the source sentence, feed them to a
:class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
save the following in a new file named :file:`fairseq/models/simple_lstm.py`::

  import torch.nn as nn
  from fairseq import utils
  from fairseq.models import FairseqEncoder


  class SimpleLSTMEncoder(FairseqEncoder):

      def __init__(
          self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
      ):
          super().__init__(dictionary)
          self.args = args

          # Our encoder will embed the inputs before feeding them to the LSTM.
          self.embed_tokens = nn.Embedding(
              num_embeddings=len(dictionary),
              embedding_dim=embed_dim,
              padding_idx=dictionary.pad(),
          )
          self.dropout = nn.Dropout(p=dropout)

          # We'll use a single-layer, unidirectional LSTM for simplicity.
          self.lstm = nn.LSTM(
              input_size=embed_dim,
              hidden_size=hidden_dim,
              num_layers=1,
              bidirectional=False,
          )

      def forward(self, src_tokens, src_lengths):
          # The inputs to the ``forward()`` function are determined by the
          # Task, and in particular the ``'net_input'`` key in each
          # mini-batch. We discuss Tasks in the next tutorial, but for now just
          # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
          # has shape `(batch)`.

          # Note that the source is typically padded on the left. This can be
          # configured by adding the `--left-pad-source "False"` command-line
          # argument, but here we'll make the Encoder handle either kind of
          # padding by converting everything to be right-padded.
          if self.args.left_pad_source:
              # Convert left-padding to right-padding.
              src_tokens = utils.convert_padding_direction(
                  src_tokens,
                  padding_idx=self.dictionary.pad(),
                  left_to_right=True
              )

          # Embed the source.
          x = self.embed_tokens(src_tokens)

          # Apply dropout.
          x = self.dropout(x)

          # Pack the sequence into a PackedSequence object to feed to the LSTM.
          x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)

          # Get the output from the LSTM.
          _outputs, (final_hidden, _final_cell) = self.lstm(x)

          # Return the Encoder's output. This can be any object and will be
          # passed directly to the Decoder.
          return {
              # this will have shape `(bsz, hidden_dim)`
              'final_hidden': final_hidden.squeeze(0),
          }

      # Encoders are required to implement this method so that we can rearrange
      # the order of the batch elements during inference (e.g., beam search).
      def reorder_encoder_out(self, encoder_out, new_order):
          """
          Reorder encoder output according to `new_order`.

          Args:
              encoder_out: output from the ``forward()`` method
              new_order (LongTensor): desired order

          Returns:
              `encoder_out` rearranged according to `new_order`
          """
          final_hidden = encoder_out['final_hidden']
          return {
              'final_hidden': final_hidden.index_select(0, new_order),
          }
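
As a quick sanity check, we can run the Encoder by hand on a toy batch. This is
only a minimal sketch: it assumes a small :class:`~fairseq.data.Dictionary`, an
``args`` namespace that provides ``left_pad_source``, and right-padded inputs
sorted by decreasing length (as ``pack_padded_sequence()`` expects)::

  import argparse
  import torch
  from fairseq.data import Dictionary

  # Toy dictionary for illustration only.
  dictionary = Dictionary()
  for word in 'hello world'.split():
      dictionary.add_symbol(word)

  args = argparse.Namespace(left_pad_source=False)
  encoder = SimpleLSTMEncoder(args, dictionary, embed_dim=16, hidden_dim=16)

  # Two right-padded source sentences of lengths 3 and 2.
  src_tokens = torch.tensor([
      [dictionary.index('hello'), dictionary.index('world'), dictionary.eos()],
      [dictionary.index('hello'), dictionary.eos(), dictionary.pad()],
  ])
  src_lengths = torch.tensor([3, 2])

  encoder_out = encoder(src_tokens, src_lengths)
  print(encoder_out['final_hidden'].shape)  # torch.Size([2, 16])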

Decoder
~~~~~~~

Our Decoder will predict the next word, conditioned on the Encoder's final
hidden state and an embedded representation of the previous target word -- which
is sometimes called *input feeding* or *teacher forcing*. More specifically,
we'll use a :class:`torch.nn.LSTM` to produce a sequence of hidden states that
we'll project to the size of the output vocabulary to predict each target word.

::

  import torch
  from fairseq.models import FairseqDecoder


  class SimpleLSTMDecoder(FairseqDecoder):

      def __init__(
          self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
          dropout=0.1,
      ):
          super().__init__(dictionary)

          # Our decoder will embed the inputs before feeding them to the LSTM.
          self.embed_tokens = nn.Embedding(
              num_embeddings=len(dictionary),
              embedding_dim=embed_dim,
              padding_idx=dictionary.pad(),
          )
          self.dropout = nn.Dropout(p=dropout)

          # We'll use a single-layer, unidirectional LSTM for simplicity.
          self.lstm = nn.LSTM(
              # For the first layer we'll concatenate the Encoder's final hidden
              # state with the embedded target tokens.
              input_size=encoder_hidden_dim + embed_dim,
              hidden_size=hidden_dim,
              num_layers=1,
              bidirectional=False,
          )

          # Define the output projection.
          self.output_projection = nn.Linear(hidden_dim, len(dictionary))

      # During training Decoders are expected to take the entire target sequence
      # (shifted right by one position) and produce logits over the vocabulary.
      # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
      # ``dictionary.eos()``, followed by the target sequence.
      def forward(self, prev_output_tokens, encoder_out):
          """
          Args:
              prev_output_tokens (LongTensor): previous decoder outputs of shape
                  `(batch, tgt_len)`, for input feeding/teacher forcing
              encoder_out (Tensor, optional): output from the encoder, used for
                  encoder-side attention

          Returns:
              tuple:
                  - the last decoder layer's output of shape
                    `(batch, tgt_len, vocab)`
                  - the last decoder layer's attention weights of shape
                    `(batch, tgt_len, src_len)`
          """
          bsz, tgt_len = prev_output_tokens.size()

          # Extract the final hidden state from the Encoder.
          final_encoder_hidden = encoder_out['final_hidden']

          # Embed the target sequence, which has been shifted right by one
          # position and now starts with the end-of-sentence symbol.
          x = self.embed_tokens(prev_output_tokens)

          # Apply dropout.
          x = self.dropout(x)

          # Concatenate the Encoder's final hidden state to *every* embedded
          # target token.
          x = torch.cat(
              [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
              dim=2,
          )

          # Using PackedSequence objects in the Decoder is harder than in the
          # Encoder, since the targets are not sorted in descending length order,
          # which is a requirement of ``pack_padded_sequence()``. Instead we'll
          # feed nn.LSTM directly.
          initial_state = (
              final_encoder_hidden.unsqueeze(0),  # hidden
              torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
          )
          output, _ = self.lstm(
              x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
              initial_state,
          )
          x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

          # Project the outputs to the size of the vocabulary.
          x = self.output_projection(x)

          # Return the logits and ``None`` for the attention weights
          return x, None
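
To see the shapes involved, here is a minimal teacher-forcing sketch for the
Decoder above. It reuses the toy ``dictionary`` from the Encoder sketch and
fakes the Encoder output, so the numbers are illustrative only::

  import torch

  bsz, hidden = 1, 16
  encoder_out = {'final_hidden': torch.zeros(bsz, hidden)}  # stand-in Encoder output

  # Teacher forcing: the decoder input is the target shifted right by one
  # position, so it begins with the end-of-sentence symbol.
  target = torch.tensor([[dictionary.index('hello'),
                          dictionary.index('world'),
                          dictionary.eos()]])
  prev_output_tokens = torch.cat(
      [target.new_full((bsz, 1), dictionary.eos()), target[:, :-1]], dim=1,
  )

  decoder = SimpleLSTMDecoder(
      dictionary, encoder_hidden_dim=hidden, embed_dim=16, hidden_dim=hidden,
  )
  logits, _ = decoder(prev_output_tokens, encoder_out)
  print(logits.shape)  # (1, 3, len(dictionary))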

2. Registering the Model
------------------------

Now that we've defined our Encoder and Decoder we must *register* our model with
fairseq using the :func:`~fairseq.models.register_model` function decorator.
Once the model is registered we'll be able to use it with the existing
:ref:`Command-line Tools`.

All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqModel` interface.

Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::

  from fairseq.models import FairseqModel, register_model

  # Note: the register_model "decorator" should immediately precede the
  # definition of the Model class.
  @register_model('simple_lstm')
  class SimpleLSTMModel(FairseqModel):

      @staticmethod
      def add_args(parser):
          # Models can override this method to add new command-line arguments.
          # Here we'll add some new command-line arguments to configure dropout
          # and the dimensionality of the embeddings and hidden states.
          parser.add_argument(
              '--encoder-embed-dim', type=int, metavar='N',
              help='dimensionality of the encoder embeddings',
          )
          parser.add_argument(
              '--encoder-hidden-dim', type=int, metavar='N',
              help='dimensionality of the encoder hidden state',
          )
          parser.add_argument(
              '--encoder-dropout', type=float, default=0.1,
              help='encoder dropout probability',
          )
          parser.add_argument(
              '--decoder-embed-dim', type=int, metavar='N',
              help='dimensionality of the decoder embeddings',
          )
          parser.add_argument(
              '--decoder-hidden-dim', type=int, metavar='N',
              help='dimensionality of the decoder hidden state',
          )
          parser.add_argument(
              '--decoder-dropout', type=float, default=0.1,
              help='decoder dropout probability',
          )

      @classmethod
      def build_model(cls, args, task):
          # Fairseq initializes models by calling the ``build_model()``
          # function. This provides more flexibility, since the returned model
          # instance can be of a different type than the one that was called.
          # In this case we'll just return a SimpleLSTMModel instance.

          # Initialize our Encoder and Decoder.
          encoder = SimpleLSTMEncoder(
              args=args,
              dictionary=task.source_dictionary,
              embed_dim=args.encoder_embed_dim,
              hidden_dim=args.encoder_hidden_dim,
              dropout=args.encoder_dropout,
          )
          decoder = SimpleLSTMDecoder(
              dictionary=task.target_dictionary,
              encoder_hidden_dim=args.encoder_hidden_dim,
              embed_dim=args.decoder_embed_dim,
              hidden_dim=args.decoder_hidden_dim,
              dropout=args.decoder_dropout,
          )
          model = SimpleLSTMModel(encoder, decoder)

          # Print the model architecture.
          print(model)

          return model

      # We could override the ``forward()`` if we wanted more control over how
      # the encoder and decoder interact, but it's not necessary for this
      # tutorial since we can inherit the default implementation provided by
      # the FairseqModel base class, which looks like:
      #
      # def forward(self, src_tokens, src_lengths, prev_output_tokens):
      #     encoder_out = self.encoder(src_tokens, src_lengths)
      #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
      #     return decoder_out
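
If you want to exercise ``build_model()`` outside of the command-line tools, a
minimal sketch looks like the following. The ``args`` and ``task`` namespaces
are stand-ins for the objects fairseq would normally construct, and the toy
``dictionary`` is the one from the Encoder sketch above::

  from argparse import Namespace

  # Hand-built stand-ins for the args/task objects fairseq creates for us.
  args = Namespace(
      encoder_embed_dim=16, encoder_hidden_dim=16, encoder_dropout=0.1,
      decoder_embed_dim=16, decoder_hidden_dim=16, decoder_dropout=0.1,
      left_pad_source=False,
  )
  task = Namespace(source_dictionary=dictionary, target_dictionary=dictionary)

  model = SimpleLSTMModel.build_model(args, task)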

Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::

  from fairseq.models import register_model_architecture

  # The first argument to ``register_model_architecture()`` should be the name
  # of the model we registered above (i.e., 'simple_lstm'). The function we
  # register here should take a single argument *args* and modify it in-place
  # to match the desired architecture.
  @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
  def tutorial_simple_lstm(args):
      # We use ``getattr()`` to prioritize arguments that are explicitly given
      # on the command-line, so that the defaults defined below are only used
      # when no other value has been specified.
      args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
      args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
      args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
      args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
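
The same model can expose more than one named architecture. For example, a
hypothetical larger variant (the name ``tutorial_simple_lstm_big`` is made up
for illustration) would only differ in its defaults::

  # A second named architecture for the same 'simple_lstm' model.
  @register_model_architecture('simple_lstm', 'tutorial_simple_lstm_big')
  def tutorial_simple_lstm_big(args):
      args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
      args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 512)
      args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
      args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 512)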

3. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`train.py`
command-line tool for this, making sure to specify our new Model architecture
(``--arch tutorial_simple_lstm``).

.. note::

  Make sure you've already preprocessed the data from the IWSLT example in the
  :file:`examples/translation/` directory.

.. code-block:: console

  > python train.py data-bin/iwslt14.tokenized.de-en \
    --arch tutorial_simple_lstm \
    --encoder-dropout 0.2 --decoder-dropout 0.2 \
    --optimizer adam --lr 0.005 --lr-shrink 0.5 \
    --max-tokens 12000
  (...)
  | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
  | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954

The model files should appear in the :file:`checkpoints/` directory. While this
model architecture is not very good, we can use the :ref:`generate.py` script to
generate translations and compute our BLEU score over the test set:

.. code-block:: console

  > python generate.py data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

4. Making generation faster
---------------------------

While autoregressive generation from sequence-to-sequence models is inherently
slow, our implementation above is especially slow because it recomputes the
entire sequence of Decoder hidden states for every output token (i.e., it is
``O(n^2)``). We can make this significantly faster by instead caching the
previous hidden states.
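
To make the cost concrete, here is a sketch of greedy decoding with the
non-incremental Decoder, reusing the toy ``dictionary``, ``decoder`` and
``encoder_out`` from the earlier sketches. At step ``t`` the ``forward()`` call
re-processes all ``t`` previous tokens::

  import torch

  tokens = torch.full((1, 1), dictionary.eos(), dtype=torch.long)
  for _ in range(10):
      # Each call re-runs the LSTM over the full prefix, so generating n
      # tokens costs O(n^2) LSTM steps in total.
      logits, _ = decoder(tokens, encoder_out)
      next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
      tokens = torch.cat([tokens, next_token], dim=1)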

In fairseq this caching is called :ref:`Incremental decoding`. Incremental
decoding is a special mode at inference time where the Model only receives a
single timestep of input corresponding to the immediately previous output token
(for input feeding) and must produce the next output incrementally. Thus the
model must cache any long-term state that is needed about the sequence, e.g.,
hidden states, convolutional states, etc.

To implement incremental decoding we will modify our model to implement the
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
decoder interface allows ``forward()`` methods to take an extra keyword argument
(*incremental_state*) that can be used to cache state across time-steps.

Let's replace our ``SimpleLSTMDecoder`` with an incremental one::

  import torch
  from fairseq.models import FairseqIncrementalDecoder


  class SimpleLSTMDecoder(FairseqIncrementalDecoder):

      def __init__(
          self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
          dropout=0.1,
      ):
          # This remains the same as before.
          super().__init__(dictionary)
          self.embed_tokens = nn.Embedding(
              num_embeddings=len(dictionary),
              embedding_dim=embed_dim,
              padding_idx=dictionary.pad(),
          )
          self.dropout = nn.Dropout(p=dropout)
          self.lstm = nn.LSTM(
              input_size=encoder_hidden_dim + embed_dim,
              hidden_size=hidden_dim,
              num_layers=1,
              bidirectional=False,
          )
          self.output_projection = nn.Linear(hidden_dim, len(dictionary))

      # We now take an additional kwarg (*incremental_state*) for caching the
      # previous hidden and cell states.
      def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
          if incremental_state is not None:
              # If the *incremental_state* argument is not ``None`` then we are
              # in incremental inference mode. While *prev_output_tokens* will
              # still contain the entire decoded prefix, we will only use the
              # last step and assume that the rest of the state is cached.
              prev_output_tokens = prev_output_tokens[:, -1:]

          # This remains the same as before.
          bsz, tgt_len = prev_output_tokens.size()
          final_encoder_hidden = encoder_out['final_hidden']
          x = self.embed_tokens(prev_output_tokens)
          x = self.dropout(x)
          x = torch.cat(
              [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
              dim=2,
          )

          # We will now check the cache and load the cached previous hidden and
          # cell states, if they exist, otherwise we will initialize them to
          # zeros (as before). We will use the ``utils.get_incremental_state()``
          # and ``utils.set_incremental_state()`` helpers.
          initial_state = utils.get_incremental_state(
              self, incremental_state, 'prev_state',
          )
          if initial_state is None:
              # first time initialization, same as the original version
              initial_state = (
                  final_encoder_hidden.unsqueeze(0),  # hidden
                  torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
              )

          # Run one step of our LSTM.
          output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

          # Update the cache with the latest hidden and cell states.
          utils.set_incremental_state(
              self, incremental_state, 'prev_state', latest_state,
          )

          # This remains the same as before
          x = output.transpose(0, 1)
          x = self.output_projection(x)
          return x, None

      # The ``FairseqIncrementalDecoder`` interface also requires implementing a
      # ``reorder_incremental_state()`` method, which is used during beam search
      # to select and reorder the incremental state.
      def reorder_incremental_state(self, incremental_state, new_order):
          # Load the cached state.
          prev_state = utils.get_incremental_state(
              self, incremental_state, 'prev_state',
          )

          # Reorder batches according to *new_order*.
          reordered_state = (
              prev_state[0].index_select(1, new_order),  # hidden
              prev_state[1].index_select(1, new_order),  # cell
          )

          # Update the cached state.
          utils.set_incremental_state(
              self, incremental_state, 'prev_state', reordered_state,
          )
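
With the incremental Decoder in place, the same greedy loop only needs to do a
constant amount of work per step: the hidden and cell states are cached in
``incremental_state`` (here a plain dict, standing in for what the sequence
generator provides). This sketch assumes a fresh instance of the incremental
``SimpleLSTMDecoder`` plus the toy ``dictionary`` and ``encoder_out`` from the
earlier sketches::

  import torch

  decoder = SimpleLSTMDecoder(
      dictionary, encoder_hidden_dim=16, embed_dim=16, hidden_dim=16,
  )
  incremental_state = {}  # caches the (hidden, cell) state between steps

  tokens = torch.full((1, 1), dictionary.eos(), dtype=torch.long)
  for _ in range(10):
      # Only the newest token is consumed; earlier steps live in the cache.
      logits, _ = decoder(tokens, encoder_out, incremental_state=incremental_state)
      next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
      tokens = torch.cat([tokens, next_token], dim=1)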

Finally, we can rerun generation and observe the speedup:

.. code-block:: console

  # Before
  > python generate.py data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

  # After
  > python generate.py data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)