Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_dictionary.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import tempfile
  8. import unittest
  9. import torch
  10. from fairseq.data import Dictionary
  11. from fairseq.tokenizer import Tokenizer
  class TestDictionary(unittest.TestCase):
      """Unit tests for ``fairseq.data.Dictionary``."""

      def test_finalize(self):
          """Check token ids before ``finalize()``, after it, and after save/load.

          The dictionary is built from a small corpus, and the ids produced by
          ``Tokenizer.tokenize`` are compared against hard-coded references at
          three points: right after building, after ``finalize()``, and after
          writing the dictionary to disk and loading it back.
          """
          # Local import: only needed to build the temporary dictionary path.
          import os

          txt = [
              'A B C D',
              'B C D',
              'C D',
              'D',
          ]
          # Expected ids before finalize().
          # NOTE(review): the trailing 2 on every row is presumably the
          # end-of-sentence index -- confirm against Dictionary.
          ref_ids1 = list(map(torch.IntTensor, [
              [4, 5, 6, 7, 2],
              [5, 6, 7, 2],
              [6, 7, 2],
              [7, 2],
          ]))
          # Expected ids after finalize() (and after reloading from disk).
          # NOTE(review): the symbol ids are reversed relative to ref_ids1,
          # which looks like a most-frequent-first reordering ('D' occurs in
          # every line) -- confirm against Dictionary.finalize.
          ref_ids2 = list(map(torch.IntTensor, [
              [7, 6, 5, 4, 2],
              [6, 5, 4, 2],
              [5, 4, 2],
              [4, 2],
          ]))

          # build dictionary
          d = Dictionary()
          for line in txt:
              Tokenizer.tokenize(line, d, add_if_not_exist=True)

          def get_ids(dictionary):
              """Tokenize every corpus line with *dictionary*, adding no new symbols."""
              return [
                  Tokenizer.tokenize(line, dictionary, add_if_not_exist=False)
                  for line in txt
              ]

          def assertMatch(ids, ref_ids):
              """Assert both tensor lists have equal length, sizes and contents."""
              # zip() stops at the shorter input, so without this check a
              # missing or extra entry would go unnoticed.
              self.assertEqual(len(ids), len(ref_ids))
              for toks, ref_toks in zip(ids, ref_ids):
                  self.assertEqual(toks.size(), ref_toks.size())
                  self.assertEqual(0, (toks != ref_toks).sum().item())

          ids = get_ids(d)
          assertMatch(ids, ref_ids1)

          # check finalized dictionary
          d.finalize()
          finalized_ids = get_ids(d)
          assertMatch(finalized_ids, ref_ids2)

          # write to disk and reload
          # A path inside a temporary directory is used instead of an open
          # NamedTemporaryFile: re-opening such a file by name while it is
          # still open does not work on every platform (notably Windows).
          with tempfile.TemporaryDirectory() as tmp_dir:
              dict_path = os.path.join(tmp_dir, 'dict.txt')
              d.save(dict_path)
              d = Dictionary.load(dict_path)

          reload_ids = get_ids(d)
          assertMatch(reload_ids, ref_ids2)
          assertMatch(finalized_ids, reload_ids)
  # Run the tests in this module when the file is executed as a script.
  if __name__ == "__main__":
      unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...