Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

encoder.py 4.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
  1. """Byte pair encoding utilities"""
  2. import os
  3. import json
  4. import regex as re
  5. from functools import lru_cache
  6. @lru_cache()
  7. def bytes_to_unicode():
  8. """
  9. Returns list of utf-8 byte and a corresponding list of unicode strings.
  10. The reversible bpe codes work on unicode strings.
  11. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
  12. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
  13. This is a signficant percentage of your normal, say, 32K bpe vocab.
  14. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
  15. And avoids mapping to whitespace/control characters the bpe code barfs on.
  16. """
  17. bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
  18. cs = bs[:]
  19. n = 0
  20. for b in range(2**8):
  21. if b not in bs:
  22. bs.append(b)
  23. cs.append(2**8+n)
  24. n += 1
  25. cs = [chr(n) for n in cs]
  26. return dict(zip(bs, cs))
  27. def get_pairs(word):
  28. """Return set of symbol pairs in a word.
  29. Word is represented as tuple of symbols (symbols being variable-length strings).
  30. """
  31. pairs = set()
  32. prev_char = word[0]
  33. for char in word[1:]:
  34. pairs.add((prev_char, char))
  35. prev_char = char
  36. return pairs
  37. class Encoder:
  38. def __init__(self, encoder, bpe_merges, errors='replace'):
  39. self.encoder = encoder
  40. self.decoder = {v:k for k,v in self.encoder.items()}
  41. self.errors = errors # how to handle errors in decoding
  42. self.byte_encoder = bytes_to_unicode()
  43. self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
  44. self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
  45. self.cache = {}
  46. # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
  47. self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
  48. def bpe(self, token):
  49. if token in self.cache:
  50. return self.cache[token]
  51. word = tuple(token)
  52. pairs = get_pairs(word)
  53. if not pairs:
  54. return token
  55. while True:
  56. bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
  57. if bigram not in self.bpe_ranks:
  58. break
  59. first, second = bigram
  60. new_word = []
  61. i = 0
  62. while i < len(word):
  63. try:
  64. j = word.index(first, i)
  65. new_word.extend(word[i:j])
  66. i = j
  67. except:
  68. new_word.extend(word[i:])
  69. break
  70. if word[i] == first and i < len(word)-1 and word[i+1] == second:
  71. new_word.append(first+second)
  72. i += 2
  73. else:
  74. new_word.append(word[i])
  75. i += 1
  76. new_word = tuple(new_word)
  77. word = new_word
  78. if len(word) == 1:
  79. break
  80. else:
  81. pairs = get_pairs(word)
  82. word = ' '.join(word)
  83. self.cache[token] = word
  84. return word
  85. def encode(self, text):
  86. bpe_tokens = []
  87. for token in re.findall(self.pat, text):
  88. token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
  89. bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
  90. return bpe_tokens
  91. def decode(self, tokens):
  92. text = ''.join([self.decoder[token] for token in tokens])
  93. text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
  94. return text
  95. def get_encoder(model_name, models_dir):
  96. with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
  97. encoder = json.load(f)
  98. with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
  99. bpe_data = f.read()
  100. bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
  101. return Encoder(
  102. encoder=encoder,
  103. bpe_merges=bpe_merges,
  104. )
Tip!

Press p (or the left arrow key) to see the previous file, or n (or the right arrow key) to see the next file

Comments

Loading...