# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import contextlib
import os

import numpy as np


def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
    src, dst = None, None
    for filename in os.listdir(path):
        parts = filename.split('.')
        if len(parts) >= 3 and len(parts[1].split('-')) == 2:
            return parts[1].split('-')
    return src, dst
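

# Illustrative sketch (not part of the original module): infer_language_pair
# scans a directory for files named like 'train.de-en.idx' and returns the
# pair ['de', 'en']. The temporary directory and filename are made up for
# this demo.
def _example_infer_language_pair():
    import tempfile
    with tempfile.TemporaryDirectory() as d:
        open(os.path.join(d, 'train.de-en.idx'), 'w').close()
        assert infer_language_pair(d) == ['de', 'en']
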
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    res = values[0].new(len(values), size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            # Shift right: move the EOS token from the end of the source to
            # the start of the destination (used to build decoder inputs).
            assert src[-1] == eos_idx
            dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res
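

# Illustrative sketch (not part of the original module): padding two toy
# sequences with collate_tokens. pad_idx=1 and eos_idx=2 are arbitrary demo
# values; left_pad=True right-aligns the shorter sequence.
def _example_collate_tokens():
    import torch
    a, b = torch.tensor([5, 6, 2]), torch.tensor([7, 2])
    padded = collate_tokens([a, b], pad_idx=1, eos_idx=2, left_pad=True)
    # padded == [[5, 6, 2],
    #            [1, 7, 2]]
    shifted = collate_tokens([a, b], pad_idx=1, eos_idx=2, left_pad=True,
                             move_eos_to_beginning=True)
    # shifted == [[2, 5, 6],
    #             [1, 2, 7]]
    return padded, shifted
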
@contextlib.contextmanager
def numpy_seed(seed):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward."""
    if seed is None:
        yield
        return
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)
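

# Illustrative sketch (not part of the original module): the same seed yields
# the same draws inside the context, and the surrounding global PRNG state is
# restored afterward.
def _example_numpy_seed():
    with numpy_seed(42):
        first = np.random.randint(0, 100, size=3)
    with numpy_seed(42):
        second = np.random.randint(0, 100, size=3)
    assert (first == second).all()
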
def collect_filtered(function, iterable, filtered):
    """
    Similar to :func:`filter` but collects filtered elements in ``filtered``.

    Args:
        function (callable): function that returns ``False`` for elements that
            should be filtered
        iterable (iterable): iterable to filter
        filtered (list): list to store filtered elements
    """
    for el in iterable:
        if function(el):
            yield el
        else:
            filtered.append(el)
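

# Illustrative sketch (not part of the original module): keep even numbers,
# collecting the filtered-out odd ones in a caller-provided list.
def _example_collect_filtered():
    filtered = []
    kept = list(collect_filtered(lambda x: x % 2 == 0, range(5), filtered))
    assert kept == [0, 2, 4] and filtered == [1, 3]
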
def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
    """
    Filter indices based on their size.

    Args:
        indices (List[int]): ordered list of dataset indices
        size_fn (callable): function that returns the size of a given index
        max_positions (int, float, dict or tuple): filter elements larger
            than this size. Comparisons are done component-wise.
        raise_exception (bool, optional): if ``True``, raise an exception if
            any elements are filtered (default: False).
    """
    def check_size(idx):
        if isinstance(max_positions, (float, int)):
            return size_fn(idx) <= max_positions
        elif isinstance(max_positions, dict):
            idx_size = size_fn(idx)
            assert isinstance(idx_size, dict)
            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
            return all(
                all(a is None or b is None or a <= b
                    for a, b in zip(idx_size[key], max_positions[key]))
                for key in intersect_keys
            )
        else:
            return all(a is None or b is None or a <= b
                       for a, b in zip(size_fn(idx), max_positions))

    ignored = []
    itr = collect_filtered(check_size, indices, ignored)
    for idx in itr:
        if len(ignored) > 0 and raise_exception:
            raise Exception((
                'Size of sample #{} is invalid (={}) since max_positions={}, '
                'skip this example with --skip-invalid-size-inputs-valid-test'
            ).format(ignored[0], size_fn(ignored[0]), max_positions))
        yield idx
    if len(ignored) > 0:
        print((
            '| WARNING: {} samples have invalid sizes and will be skipped, '
            'max_positions={}, first few sample ids={}'
        ).format(len(ignored), max_positions, ignored[:10]))
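

# Illustrative sketch (not part of the original module): with a scalar
# max_positions, any index whose size exceeds the cap is dropped and reported
# in the warning. The toy sizes below are made up for the demo.
def _example_filter_by_size():
    sizes = [10, 25, 5]
    kept = list(filter_by_size(range(3), lambda i: sizes[i], max_positions=20))
    assert kept == [0, 2]
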
def batch_by_size(
    indices, num_tokens_fn, max_tokens=None, max_sentences=None,
    required_batch_size_multiple=1,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens
            at a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each batch
            (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
    """
    max_tokens = max_tokens if max_tokens is not None else float('Inf')
    max_sentences = max_sentences if max_sentences is not None else float('Inf')
    bsz_mult = required_batch_size_multiple
    batch = []

    def is_batch_full(num_tokens):
        if len(batch) == 0:
            return False
        if len(batch) == max_sentences:
            return True
        if num_tokens > max_tokens:
            return True
        return False

    sample_len = 0
    sample_lens = []
    for idx in indices:
        sample_lens.append(num_tokens_fn(idx))
        sample_len = max(sample_len, sample_lens[-1])
        assert sample_len <= max_tokens, \
            f'sentence at index {idx} exceeds max_tokens limit!'
        # Cost of the batch if this sample were added: every sequence gets
        # padded to the longest one seen so far.
        num_tokens = (len(batch) + 1) * sample_len
        if is_batch_full(num_tokens):
            # Trim the batch to the largest multiple of bsz_mult and carry
            # the remainder over into the next batch.
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            yield batch[:mod_len]
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        yield batch
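

# Illustrative sketch (not part of the original module): with made-up token
# counts [4, 4, 4, 6] and max_tokens=8, the padding-aware cost model yields
# [[0, 1], [2], [3]] -- each batch's (batch size * longest sample) stays
# within the token cap.
def _example_batch_by_size():
    lengths = [4, 4, 4, 6]
    batches = list(batch_by_size(range(4), lambda i: lengths[i], max_tokens=8))
    assert batches == [[0, 1], [2], [3]]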