Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

language_pair_dataset.py 9.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import numpy as np
  8. import torch
  9. from fairseq import utils
  10. from . import data_utils, FairseqDataset
  def collate(
      samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
      input_feeding=True,
  ):
      """Merge a list of samples into a mini-batch sorted by source length.

      Args:
          samples (List[dict]): samples with keys ``'id'``, ``'source'`` and
              optionally ``'target'``.
          pad_idx (int): padding index forwarded to ``data_utils.collate_tokens``.
          eos_idx (int): end-of-sentence index forwarded to
              ``data_utils.collate_tokens``.
          left_pad_source (bool): pad source tensors on the left side.
          left_pad_target (bool): pad target tensors on the left side.
          input_feeding (bool): also build ``prev_output_tokens`` from the
              targets with ``move_eos_to_beginning=True``.

      Returns:
          dict: ``{}`` when *samples* is empty. Otherwise a batch with keys
          ``id``, ``nsentences``, ``ntokens``, ``net_input`` and ``target``,
          where every per-sample tensor is reordered by descending source
          length. ``net_input`` gains ``prev_output_tokens`` only when
          *input_feeding* is set and targets are present.
      """
      if len(samples) == 0:
          return {}

      def merge(key, left_pad, move_eos_to_beginning=False):
          # Pad the `key` tensors of all samples into one batch tensor.
          return data_utils.collate_tokens(
              [s[key] for s in samples],
              pad_idx, eos_idx, left_pad, move_eos_to_beginning,
          )

      # Named `ids` (not `id`) to avoid shadowing the builtin.
      ids = torch.LongTensor([s['id'] for s in samples])
      src_tokens = merge('source', left_pad=left_pad_source)
      # sort by descending source length
      src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
      src_lengths, sort_order = src_lengths.sort(descending=True)
      ids = ids.index_select(0, sort_order)
      src_tokens = src_tokens.index_select(0, sort_order)

      prev_output_tokens = None
      target = None
      # Whether the batch has targets is decided from the first sample only.
      if samples[0].get('target', None) is not None:
          target = merge('target', left_pad=left_pad_target)
          target = target.index_select(0, sort_order)
          ntokens = sum(len(s['target']) for s in samples)
          if input_feeding:
              # we create a shifted version of targets for feeding the
              # previous output token(s) into the next decoder step
              prev_output_tokens = merge(
                  'target',
                  left_pad=left_pad_target,
                  move_eos_to_beginning=True,
              )
              prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
      else:
          # Without targets, the token count is taken from the sources.
          ntokens = sum(len(s['source']) for s in samples)

      batch = {
          'id': ids,
          'nsentences': len(samples),
          'ntokens': ntokens,
          'net_input': {
              'src_tokens': src_tokens,
              'src_lengths': src_lengths,
          },
          'target': target,
      }
      if prev_output_tokens is not None:
          batch['net_input']['prev_output_tokens'] = prev_output_tokens
      return batch
  class LanguagePairDataset(FairseqDataset):
      """
      A pair of torch.utils.data.Datasets.

      Args:
          src (torch.utils.data.Dataset): source dataset to wrap
          src_sizes (List[int]): source sentence lengths
          src_dict (~fairseq.data.Dictionary): source vocabulary
          tgt (torch.utils.data.Dataset, optional): target dataset to wrap
          tgt_sizes (List[int], optional): target sentence lengths
          tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary.
              If given, its pad/eos/unk indices must equal *src_dict*'s.
          left_pad_source (bool, optional): pad source tensors on the left side
              (default: True).
          left_pad_target (bool, optional): pad target tensors on the left side
              (default: False).
          max_source_positions (int, optional): max number of tokens in the
              source sentence (default: 1024).
          max_target_positions (int, optional): max number of tokens in the
              target sentence (default: 1024).
          shuffle (bool, optional): shuffle dataset elements before batching
              (default: True).
          input_feeding (bool, optional): create a shifted version of the targets
              to be passed into the model for input feeding/teacher forcing
              (default: True).
          remove_eos_from_source (bool, optional): if set, removes eos from end
              of source if it's present (default: False).
          append_eos_to_target (bool, optional): if set, appends eos to end of
              target if it's absent (default: False).
      """

      def __init__(
          self, src, src_sizes, src_dict,
          tgt=None, tgt_sizes=None, tgt_dict=None,
          left_pad_source=True, left_pad_target=False,
          max_source_positions=1024, max_target_positions=1024,
          shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False,
      ):
          if tgt_dict is not None:
              # collater() pads both sides using src_dict's pad/eos indices,
              # so the special symbols must agree between the vocabularies.
              assert src_dict.pad() == tgt_dict.pad()
              assert src_dict.eos() == tgt_dict.eos()
              assert src_dict.unk() == tgt_dict.unk()
          self.src = src
          self.tgt = tgt
          self.src_sizes = np.array(src_sizes)
          self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
          self.src_dict = src_dict
          self.tgt_dict = tgt_dict
          self.left_pad_source = left_pad_source
          self.left_pad_target = left_pad_target
          self.max_source_positions = max_source_positions
          self.max_target_positions = max_target_positions
          self.shuffle = shuffle
          self.input_feeding = input_feeding
          self.remove_eos_from_source = remove_eos_from_source
          self.append_eos_to_target = append_eos_to_target

      def __getitem__(self, index):
          """Return example *index* as a dict with keys ``'id'``, ``'source'``
          and ``'target'`` (``None`` when there is no target dataset)."""
          tgt_item = self.tgt[index] if self.tgt is not None else None
          src_item = self.src[index]
          # Append EOS to end of tgt sentence if it does not have an EOS and
          # remove EOS from end of src sentence if it exists. This is useful
          # when we use existing datasets for opposite directions, i.e. when we
          # want to use tgt_dataset as src_dataset and vice versa.
          # The items fetched above are reused rather than indexing the
          # underlying datasets again.
          if self.append_eos_to_target and tgt_item is not None:
              eos = self.tgt_dict.eos() if self.tgt_dict is not None else self.src_dict.eos()
              # An empty target has no EOS, so it gets one appended.
              if len(tgt_item) == 0 or tgt_item[-1] != eos:
                  tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])
          if self.remove_eos_from_source:
              eos = self.src_dict.eos()
              # Guard against empty sources, where `[-1]` would raise.
              if len(src_item) > 0 and src_item[-1] == eos:
                  src_item = src_item[:-1]
          return {
              'id': index,
              'source': src_item,
              'target': tgt_item,
          }

      def __len__(self):
          return len(self.src)

      def collater(self, samples):
          """Merge a list of samples to form a mini-batch.

          Args:
              samples (List[dict]): samples to collate

          Returns:
              dict: ``{}`` if *samples* is empty, otherwise a mini-batch with
              the following keys:

                  - `id` (LongTensor): dataset indices of the examples,
                    reordered by descending source length like every other
                    per-example tensor in the batch
                  - `nsentences` (int): number of examples in the batch
                  - `ntokens` (int): total number of target tokens in the
                    batch (source tokens if there is no target)
                  - `net_input` (dict): the input to the Model, containing keys:

                    - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                      the source sentence of shape `(bsz, src_len)`. Padding will
                      appear on the left if *left_pad_source* is ``True``.
                    - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                      lengths of each source sentence of shape `(bsz)`
                    - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                      tokens in the target sentence, shifted right by one position
                      for input feeding/teacher forcing, of shape `(bsz,
                      tgt_len)`. This key will not be present if *input_feeding*
                      is ``False``. Padding will appear on the left if
                      *left_pad_target* is ``True``.

                  - `target` (LongTensor): a padded 2D Tensor of tokens in the
                    target sentence of shape `(bsz, tgt_len)`. Padding will appear
                    on the left if *left_pad_target* is ``True``.
          """
          return collate(
              samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
              left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
              input_feeding=self.input_feeding,
          )

      def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
          """Return a dummy batch with a given number of tokens.

          The batch size is ``num_tokens // max(src_len, tgt_len)`` (at least
          1), after clamping the lengths with ``utils.resolve_max_positions``.
          """
          src_len, tgt_len = utils.resolve_max_positions(
              (src_len, tgt_len),
              max_positions,
              (self.max_source_positions, self.max_target_positions),
          )
          bsz = max(num_tokens // max(src_len, tgt_len), 1)
          return self.collater([
              {
                  'id': i,
                  'source': self.src_dict.dummy_sentence(src_len),
                  'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
              }
              for i in range(bsz)
          ])

      def num_tokens(self, index):
          """Return the number of tokens in a sample. This value is used to
          enforce ``--max-tokens`` during batching."""
          return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

      def size(self, index):
          """Return an example's size as a float or tuple. This value is used when
          filtering a dataset with ``--max-positions``."""
          return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

      def ordered_indices(self):
          """Return an ordered list of indices. Batches will be constructed based
          on this order."""
          if self.shuffle:
              indices = np.random.permutation(len(self))
          else:
              indices = np.arange(len(self))
          # Sort by target length first, then stably (mergesort) by source
          # length, so examples with equal source length stay ordered by
          # target length.
          if self.tgt_sizes is not None:
              indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
          return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

      @property
      def supports_prefetch(self):
          # A source-only dataset (tgt is None) supports prefetching whenever
          # its source dataset does; otherwise both sides must support it.
          return (
              getattr(self.src, 'supports_prefetch', False)
              and (self.tgt is None or getattr(self.tgt, 'supports_prefetch', False))
          )

      def prefetch(self, indices):
          self.src.prefetch(indices)
          # There is nothing to prefetch on the target side without a tgt dataset.
          if self.tgt is not None:
              self.tgt.prefetch(indices)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...