Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

indexer.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
  1. """
  2. Index module used to map string entities (words, characters) to integer indices.
  3. The index can be used with an embedding
  4. """
  5. import string
  6. import pickle
  7. from src.data.constants import PAD_TOKEN, UNKNOWN_TOKEN
  8. class Indexer:
  9. def __init__(self):
  10. self.key2idx = {}
  11. self.idx2key = []
  12. def add(self, key):
  13. if key not in self.key2idx:
  14. self.key2idx[key] = len(self.idx2key)
  15. self.idx2key.append(key)
  16. return self.key2idx[key]
  17. def __getitem__(self, key):
  18. if isinstance(key, str):
  19. return self.key2idx[key]
  20. if isinstance(key, int):
  21. return self.idx2key[key]
  22. def save(self, f):
  23. with open(f, 'wt', encoding='utf-8') as fout:
  24. for index, key in enumerate(self.idx2key):
  25. fout.write(key + '\t' + str(index) + '\n')
  26. def load(self, f):
  27. with open(f, 'rt', encoding='utf-8') as fin:
  28. for line in fin:
  29. line = line.strip()
  30. if not line:
  31. continue
  32. key = line.split()[0]
  33. self.add(key)
  34. class Vocabulary(Indexer):
  35. def __init__(self):
  36. super().__init__()
  37. self.add(PAD_TOKEN)
  38. self.add(UNKNOWN_TOKEN)
  39. def __getitem__(self, key):
  40. if isinstance(key, str) and key not in self.key2idx:
  41. return self.key2idx[UNKNOWN_TOKEN]
  42. return super().__getitem__(key)
  43. class Charset(Indexer):
  44. def __init__(self):
  45. super().__init__()
  46. for char in string.printable[0:-6]:
  47. self.add(char)
  48. self.add(PAD_TOKEN)
  49. self.add(UNKNOWN_TOKEN)
  50. @staticmethod
  51. def type(char):
  52. if char in string.digits:
  53. return "Digits"
  54. if char in string.ascii_lowercase:
  55. return "Lower Case"
  56. if char in string.ascii_uppercase:
  57. return "Upper Case"
  58. if char in string.punctuation:
  59. return "Punctuation"
  60. return "Other"
  61. def __getitem__(self, key):
  62. if isinstance(key, str) and key not in self.key2idx:
  63. return self.key2idx[UNKNOWN_TOKEN]
  64. return super().__getitem__(key)
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...