# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import tempfile
import unittest

import torch

from fairseq.data import Dictionary
from fairseq.tokenizer import Tokenizer
class TestDictionary(unittest.TestCase):
    """Tests for fairseq's Dictionary: finalization and disk round-trip."""

    def test_finalize(self):
        """Verify that finalize() re-sorts symbols by frequency and that a
        saved-then-reloaded dictionary produces identical token IDs."""
        corpus = [
            'A B C D',
            'B C D',
            'C D',
            'D',
        ]
        # Before finalize(): symbols keep insertion order (A=4 ... D=7);
        # trailing 2 is the EOS index appended by the tokenizer.
        expected_before = [
            torch.IntTensor(row)
            for row in ([4, 5, 6, 7, 2], [5, 6, 7, 2], [6, 7, 2], [7, 2])
        ]
        # After finalize(): symbols are re-sorted by descending count
        # (D is most frequent, so D=4 ... A=7).
        expected_after = [
            torch.IntTensor(row)
            for row in ([7, 6, 5, 4, 2], [6, 5, 4, 2], [5, 4, 2], [4, 2])
        ]

        # Build the dictionary from the corpus, adding unseen symbols.
        d = Dictionary()
        for sentence in corpus:
            Tokenizer.tokenize(sentence, d, add_if_not_exist=True)

        def encode_corpus(dictionary):
            # Map each sentence to its tensor of token IDs (no new symbols).
            return [
                Tokenizer.tokenize(sentence, dictionary, add_if_not_exist=False)
                for sentence in corpus
            ]

        def check_match(actual, expected):
            # Element-wise tensor comparison for every sentence.
            for got, want in zip(actual, expected):
                self.assertEqual(got.size(), want.size())
                self.assertEqual(0, (got != want).sum().item())

        check_match(encode_corpus(d), expected_before)

        # finalize() reorders entries by frequency; IDs change accordingly.
        d.finalize()
        finalized_ids = encode_corpus(d)
        check_match(finalized_ids, expected_after)

        # Round-trip through disk: save, reload, and re-encode.
        with tempfile.NamedTemporaryFile(mode='w') as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reloaded_ids = encode_corpus(d)
            check_match(reloaded_ids, expected_after)
            check_match(finalized_ids, reloaded_ids)
# Allow running this test module directly: `python test_dictionary.py`.
if __name__ == '__main__':
    unittest.main()