1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
|
- # Copyright (c) 2017-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the LICENSE file in
- # the root directory of this source tree. An additional grant of patent rights
- # can be found in the PATENTS file in the same directory.
- from typing import Dict, List, Optional
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from . import FairseqDecoder, FairseqEncoder
- from fairseq.data import Dictionary
class BaseFairseqModel(nn.Module):
    """Base class for fairseq models.

    Adds checkpoint-upgrade hooks, probability normalization, and
    generation / ONNX-export preparation helpers on top of :class:`nn.Module`.
    """

    def __init__(self):
        super().__init__()
        # flipped by make_generation_fast_() so that it only runs once
        self._is_generation_fast = False

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser.

        The base implementation adds nothing; subclasses override it.
        """
        pass

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError('FairseqModels must implement the build_model method')

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output.

        The base implementation ignores *net_output* and returns
        ``sample['target']``.
        """
        return sample['target']

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output.

        Args:
            net_output: the network output. If the model has a ``decoder``
                it is forwarded unchanged to
                ``decoder.get_normalized_probs``; otherwise it must be a
                tensor of logits, normalized over its last dimension.
            log_probs (bool): return log-probabilities instead of
                probabilities.
            sample (dict, optional): forwarded to the decoder.

        Raises:
            NotImplementedError: if the model has no ``decoder`` and
                *net_output* is not a tensor.
        """
        if hasattr(self, 'decoder'):
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)
        if torch.is_tensor(net_output):
            # cast to float32 before normalizing (presumably for numerical
            # stability with half-precision logits)
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model (``None`` in the base class)."""
        return None

    def max_decoder_positions(self):
        """Maximum length supported by the decoder.

        Delegates to ``self.decoder.max_positions()``, so it only works on
        subclasses that define a ``decoder``.
        """
        return self.decoder.max_positions()

    def load_state_dict(self, state_dict, strict=True):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints (in
        place, via :meth:`upgrade_state_dict`) before loading them.

        Returns:
            whatever :meth:`nn.Module.load_state_dict` returns (in recent
            PyTorch versions, a named tuple with ``missing_keys`` and
            ``unexpected_keys``).
        """
        self.upgrade_state_dict(state_dict)
        # propagate the parent's result instead of silently discarding it
        return super().load_state_dict(state_dict, strict)

    def upgrade_state_dict(self, state_dict):
        """Upgrade old state dicts to work with newer code.

        Same as :meth:`upgrade_state_dict_named` with an empty name prefix.
        """
        self.upgrade_state_dict_named(state_dict, '')

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        Walks all descendant modules and calls their
        ``upgrade_state_dict_named(state_dict, qualified_name)`` hook, or
        ``upgrade_state_dict(state_dict)`` when only that one exists.

        Args:
            state_dict (dict): state dictionary to upgrade, in place
            name (str): the state dict key corresponding to the current module
        """
        assert state_dict is not None

        def do_upgrade(m, prefix):
            if len(prefix) > 0:
                prefix += '.'

            for n, c in m.named_children():
                name = prefix + n
                if hasattr(c, 'upgrade_state_dict_named'):
                    c.upgrade_state_dict_named(state_dict, name)
                elif hasattr(c, 'upgrade_state_dict'):
                    c.upgrade_state_dict(state_dict)
                # NOTE(review): a child that is itself a BaseFairseqModel
                # already walked its own children inside the call above, and
                # is walked again here, so some hooks may run more than once;
                # they are presumably idempotent -- confirm.
                do_upgrade(c, name)

        do_upgrade(self, name)

    def make_generation_fast_(self, **kwargs):
        """Optimize model for faster generation.

        Only the first call has an effect. It removes weight normalization
        from every module that has it, calls ``make_generation_fast_`` (with
        the same *kwargs*) once on each submodule that defines it, switches
        the model to eval mode, and replaces ``self.train`` so that
        ``train(True)`` raises :class:`RuntimeError`.
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(apply_remove_weight_norm)

        # guards against calling the hook twice on a module reachable
        # through more than one parent
        seen = set()

        def apply_make_generation_fast_(module):
            # apply() also visits the root module; skip it
            if (module is not self
                    and hasattr(module, 'make_generation_fast_')
                    and module not in seen):
                seen.add(module)
                module.make_generation_fast_(**kwargs)

        self.apply(apply_make_generation_fast_)

        def train(mode=True):
            if mode:
                raise RuntimeError('cannot train after make_generation_fast')

        # this model should no longer be used for training
        self.eval()
        # NOTE(review): this replacement returns None, and nn.Module.eval()
        # is documented as equivalent to train(False), so a later
        # model.eval() no longer returns the module -- confirm no caller
        # chains on it.
        self.train = train

    def prepare_for_onnx_export_(self, **kwargs):
        """Make model exportable via ONNX trace.

        Calls ``prepare_for_onnx_export_`` (with the same *kwargs*) once on
        every submodule that defines it.
        """
        seen = set()

        def apply_prepare_for_onnx_export_(module):
            # apply() also visits the root module; skip it
            if (module is not self
                    and hasattr(module, 'prepare_for_onnx_export_')
                    and module not in seen):
                seen.add(module)
                module.prepare_for_onnx_export_(**kwargs)

        self.apply(apply_prepare_for_onnx_export_)
class FairseqModel(BaseFairseqModel):
    """Base class for encoder-decoder models.

    Args:
        encoder (FairseqEncoder): consumes the source tokens
        decoder (FairseqDecoder): consumes the previous target tokens together
            with the encoder's output
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # both halves must be fairseq components
        assert isinstance(self.encoder, FairseqEncoder)
        assert isinstance(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """
        Run the encoder, then the decoder, on one batch.

        The source batch goes through the encoder first; its output is then
        handed to the decoder together with the previous decoder outputs
        (input feeding / teacher forcing). This is equivalent to::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): source-language tokens of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing

        Returns:
            whatever the decoder returns, typically of shape
            `(batch, tgt_len, vocab)`
        """
        enc_out = self.encoder(src_tokens, src_lengths)
        return self.decoder(prev_output_tokens, enc_out)

    def max_positions(self):
        """Maximum length supported by the model.

        Returns:
            tuple: ``(encoder max positions, decoder max positions)``
        """
        src_limit = self.encoder.max_positions()
        tgt_limit = self.decoder.max_positions()
        return src_limit, tgt_limit
class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models.

    Args:
        encoders (dict): maps each key to a :class:`FairseqEncoder`
        decoders (dict): maps the same keys to :class:`FairseqDecoder` objects

    Each key's encoder/decoder pair is wrapped in a :class:`FairseqModel` and
    stored in ``self.models``; ``self.keys`` records the key order.
    """

    def __init__(self, encoders, decoders):
        super().__init__()
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            assert isinstance(encoders[key], FairseqEncoder)
            assert isinstance(decoders[key], FairseqDecoder)
        pairs = {}
        for key in self.keys:
            pairs[key] = FairseqModel(encoders[key], decoders[key])
        self.models = nn.ModuleDict(pairs)

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Build a single embedding shared by several languages.

        The dictionaries of all *langs* must compare equal; the embedding is
        then built from the first language's dictionary.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: called as ``build_embedding(dictionary,
                embed_dim, pretrained_embed_path)`` to create the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings

        Returns:
            whatever *build_embedding* returns

        Raises:
            ValueError: if the dictionaries of *langs* are not all equal.
        """
        shared_dict = dicts[langs[0]]
        dicts_differ = any(dicts[lang] != shared_dict for lang in langs)
        if dicts_differ:
            raise ValueError(
                '--share-*-embeddings requires a joined dictionary: '
                '--share-encoder-embeddings requires a joined source '
                'dictionary, --share-decoder-embeddings requires a joined '
                'target dictionary, and --share-all-embeddings requires a '
                'joint source + target dictionary.'
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Run every wrapped encoder-decoder pair on the same batch.

        Returns:
            dict: key -> that pair's decoder output, in ``self.keys`` order
        """
        outputs = {}
        for key in self.keys:
            pair = self.models[key]
            encoded = pair.encoder(src_tokens, src_lengths)
            outputs[key] = pair.decoder(prev_output_tokens, encoded)
        return outputs

    def max_positions(self):
        """Maximum length supported by the model.

        Returns:
            dict: key -> ``(encoder max positions, decoder max positions)``
        """
        limits = {}
        for key in self.keys:
            pair = self.models[key]
            limits[key] = (pair.encoder.max_positions(), pair.decoder.max_positions())
        return limits

    def max_decoder_positions(self):
        """Maximum length supported by the decoder: the smallest limit among
        all wrapped decoders."""
        decoder_limits = [pair.decoder.max_positions() for pair in self.models.values()]
        return min(decoder_limits)

    @property
    def encoder(self):
        """Encoder of the first wrapped model (key ``self.keys[0]``)."""
        first_key = self.keys[0]
        return self.models[first_key].encoder

    @property
    def decoder(self):
        """Decoder of the first wrapped model (key ``self.keys[0]``)."""
        first_key = self.keys[0]
        return self.models[first_key].decoder
class FairseqLanguageModel(BaseFairseqModel):
    """Base class for decoder-only models.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        # the single component must be a fairseq decoder
        assert isinstance(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, src_lengths):
        """
        Run the forward pass for a decoder-only model.

        The tokens are fed straight into the decoder to predict the next
        tokens; *src_lengths* is accepted but not passed to the decoder.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            whatever the decoder returns, typically of shape
            `(batch, seq_len, vocab)`
        """
        decoder_out = self.decoder(src_tokens)
        return decoder_out

    def max_positions(self):
        """Maximum length supported by the model (the decoder's limit)."""
        return self.decoder.max_positions()

    @property
    def supported_targets(self):
        """Target types this model supports; the base class reports only ``'future'``."""
        targets = {'future'}
        return targets

    def remove_head(self):
        """Remove the head of the model (e.g. the softmax layer) to conserve
        space when it is not needed.

        Raises:
            NotImplementedError: always in the base class.
        """
        raise NotImplementedError()
|