collation.py

from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch


class TextTokenCollater:
    """Collate lists of text tokens.

    Maps sentences to integers. Sentences are padded to equal length.
    Beginning and end-of-sequence symbols can be added.

    Example:
        >>> token_collater = TextTokenCollater(text_tokens)
        >>> tokens_batch, tokens_lens = token_collater(text)

    Returns:
        tokens_batch: IntTensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>
            but before padding.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # Build the vocabulary: special symbols first, then the sorted tokens,
        # so that the pad symbol always maps to index 0.
        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = [token for token in unique_tokens]

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Convert token sequences to padded index tensors via the vocabulary.
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            # Every token must be present in the vocabulary.
            assert all(s in self.token2idx for s in tokens)
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        # Pad every sequence to the length of the longest one.
        max_len = max(seq_lens)
        for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        # Lengths include <bos>/<eos> but exclude padding.
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        # Split each input into individual symbols. Unlike index(), the values
        # are cast directly to int64, so they are expected to already be
        # integer token ids (or values convertible to int64).
        tokens_seqs = [[p for p in text] for text in texts]
        max_len = len(max(tokens_seqs, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in tokens_seqs
        ]

        tokens_batch = torch.from_numpy(
            np.array(
                [seq for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in tokens_seqs
            ]
        )

        return tokens_batch, tokens_lens


def get_text_token_collater() -> TextTokenCollater:
    # Default collater: no <bos>/<eos>, placeholder vocabulary.
    collater = TextTokenCollater(
        ['0'], add_bos=False, add_eos=False
    )
    return collater
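
A minimal usage sketch, not part of collation.py: the phoneme-style vocabulary and the integer token ids below are invented for illustration, and the printed shapes follow from the code above.

# Hypothetical example of using TextTokenCollater; vocabulary and ids are made up.
if __name__ == "__main__":
    vocab = ["AH", "B", "K", "T", " "]  # illustrative token set
    collater = TextTokenCollater(vocab, add_bos=True, add_eos=True)

    # index() expects sequences of tokens that exist in the vocabulary.
    batch, lens = collater.index([["K", "AH", "T"], ["B", "AH"]])
    print(batch.shape)  # torch.Size([2, 5]): 3 tokens + <bos>/<eos>, padded to 5
    print(lens)         # tensor([5, 4], dtype=torch.int32)

    # get_text_token_collater() disables <bos>/<eos>; __call__ casts sequences
    # directly to int64, so they are assumed to already be integer token ids.
    id_collater = get_text_token_collater()
    id_batch, id_lens = id_collater([[23, 5, 17]])
    print(id_batch)     # tensor([[23, 5, 17]])
    print(id_lens)      # tensor([3], dtype=torch.int32)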