Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_character_token_embedder.py 1.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import torch
  8. import unittest
  9. from fairseq.data import Dictionary
  10. from fairseq.modules import CharacterTokenEmbedder
  11. class TestCharacterTokenEmbedder(unittest.TestCase):
  12. def test_character_token_embedder(self):
  13. vocab = Dictionary()
  14. vocab.add_symbol('hello')
  15. vocab.add_symbol('there')
  16. embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2)
  17. test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']]
  18. max_len = max(len(s) for s in test_sents)
  19. input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad())
  20. for i in range(len(test_sents)):
  21. input[i][0] = vocab.eos()
  22. for j in range(len(test_sents[i])):
  23. input[i][j + 1] = vocab.index(test_sents[i][j])
  24. input[i][j + 2] = vocab.eos()
  25. embs = embedder(input)
  26. assert embs.size() == (len(test_sents), max_len + 2, 5)
  27. self.assertAlmostEqual(embs[0][0], embs[1][0])
  28. self.assertAlmostEqual(embs[0][0], embs[0][-1])
  29. self.assertAlmostEqual(embs[0][1], embs[2][1])
  30. self.assertAlmostEqual(embs[0][3], embs[1][1])
  31. embs.sum().backward()
  32. assert embedder.char_embeddings.weight.grad is not None
  33. def assertAlmostEqual(self, t1, t2):
  34. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  35. self.assertLess((t1 - t2).abs().max(), 1e-6)
  36. if __name__ == '__main__':
  37. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...