Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

data_utils.py 5.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import contextlib
  8. import os
  9. import numpy as np
  10. def infer_language_pair(path):
  11. """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
  12. src, dst = None, None
  13. for filename in os.listdir(path):
  14. parts = filename.split('.')
  15. if len(parts) >= 3 and len(parts[1].split('-')) == 2:
  16. return parts[1].split('-')
  17. return src, dst
  18. def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
  19. """Convert a list of 1d tensors into a padded 2d tensor."""
  20. size = max(v.size(0) for v in values)
  21. res = values[0].new(len(values), size).fill_(pad_idx)
  22. def copy_tensor(src, dst):
  23. assert dst.numel() == src.numel()
  24. if move_eos_to_beginning:
  25. assert src[-1] == eos_idx
  26. dst[0] = eos_idx
  27. dst[1:] = src[:-1]
  28. else:
  29. dst.copy_(src)
  30. for i, v in enumerate(values):
  31. copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
  32. return res
  33. @contextlib.contextmanager
  34. def numpy_seed(seed):
  35. """Context manager which seeds the NumPy PRNG with the specified seed and
  36. restores the state afterward"""
  37. if seed is None:
  38. yield
  39. return
  40. state = np.random.get_state()
  41. np.random.seed(seed)
  42. try:
  43. yield
  44. finally:
  45. np.random.set_state(state)
  46. def collect_filtered(function, iterable, filtered):
  47. """
  48. Similar to :func:`filter` but collects filtered elements in ``filtered``.
  49. Args:
  50. function (callable): function that returns ``False`` for elements that
  51. should be filtered
  52. iterable (iterable): iterable to filter
  53. filtered (list): list to store filtered elements
  54. """
  55. for el in iterable:
  56. if function(el):
  57. yield el
  58. else:
  59. filtered.append(el)
  60. def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
  61. """
  62. Filter indices based on their size.
  63. Args:
  64. indices (List[int]): ordered list of dataset indices
  65. size_fn (callable): function that returns the size of a given index
  66. max_positions (tuple): filter elements larger than this size.
  67. Comparisons are done component-wise.
  68. raise_exception (bool, optional): if ``True``, raise an exception if
  69. any elements are filtered (default: False).
  70. """
  71. def check_size(idx):
  72. if isinstance(max_positions, float) or isinstance(max_positions, int):
  73. return size_fn(idx) <= max_positions
  74. elif isinstance(max_positions, dict):
  75. idx_size = size_fn(idx)
  76. assert isinstance(idx_size, dict)
  77. intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
  78. return all(
  79. all(a is None or b is None or a <= b
  80. for a, b in zip(idx_size[key], max_positions[key]))
  81. for key in intersect_keys
  82. )
  83. else:
  84. return all(a is None or b is None or a <= b
  85. for a, b in zip(size_fn(idx), max_positions))
  86. ignored = []
  87. itr = collect_filtered(check_size, indices, ignored)
  88. for idx in itr:
  89. if len(ignored) > 0 and raise_exception:
  90. raise Exception((
  91. 'Size of sample #{} is invalid (={}) since max_positions={}, '
  92. 'skip this example with --skip-invalid-size-inputs-valid-test'
  93. ).format(ignored[0], size_fn(ignored[0]), max_positions))
  94. yield idx
  95. if len(ignored) > 0:
  96. print((
  97. '| WARNING: {} samples have invalid sizes and will be skipped, '
  98. 'max_positions={}, first few sample ids={}'
  99. ).format(len(ignored), max_positions, ignored[:10]))
  100. def batch_by_size(
  101. indices, num_tokens_fn, max_tokens=None, max_sentences=None,
  102. required_batch_size_multiple=1,
  103. ):
  104. """
  105. Yield mini-batches of indices bucketed by size. Batches may contain
  106. sequences of different lengths.
  107. Args:
  108. indices (List[int]): ordered list of dataset indices
  109. num_tokens_fn (callable): function that returns the number of tokens at
  110. a given index
  111. max_tokens (int, optional): max number of tokens in each batch
  112. (default: None).
  113. max_sentences (int, optional): max number of sentences in each
  114. batch (default: None).
  115. required_batch_size_multiple (int, optional): require batch size to
  116. be a multiple of N (default: 1).
  117. """
  118. max_tokens = max_tokens if max_tokens is not None else float('Inf')
  119. max_sentences = max_sentences if max_sentences is not None else float('Inf')
  120. bsz_mult = required_batch_size_multiple
  121. batch = []
  122. def is_batch_full(num_tokens):
  123. if len(batch) == 0:
  124. return False
  125. if len(batch) == max_sentences:
  126. return True
  127. if num_tokens > max_tokens:
  128. return True
  129. return False
  130. sample_len = 0
  131. sample_lens = []
  132. for idx in indices:
  133. sample_lens.append(num_tokens_fn(idx))
  134. sample_len = max(sample_len, sample_lens[-1])
  135. assert sample_len <= max_tokens, f"sentence at index {idx} exceeds max_tokens limit!"
  136. num_tokens = (len(batch) + 1) * sample_len
  137. if is_batch_full(num_tokens):
  138. mod_len = max(
  139. bsz_mult * (len(batch) // bsz_mult),
  140. len(batch) % bsz_mult,
  141. )
  142. yield batch[:mod_len]
  143. batch = batch[mod_len:]
  144. sample_lens = sample_lens[mod_len:]
  145. sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
  146. batch.append(idx)
  147. if len(batch) > 0:
  148. yield batch
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...