Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_sequence_scorer.py 3.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import argparse
  8. import unittest
  9. import torch
  10. from fairseq.sequence_scorer import SequenceScorer
  11. import tests.utils as test_utils
  12. class TestSequenceScorer(unittest.TestCase):
  13. def test_sequence_scorer(self):
  14. # construct dummy dictionary
  15. d = test_utils.dummy_dictionary(vocab_size=2)
  16. self.assertEqual(d.pad(), 1)
  17. self.assertEqual(d.eos(), 2)
  18. self.assertEqual(d.unk(), 3)
  19. eos = d.eos()
  20. w1 = 4
  21. w2 = 5
  22. # construct dataloader
  23. data = [
  24. {
  25. 'source': torch.LongTensor([w1, w2, eos]),
  26. 'target': torch.LongTensor([w1, w2, w1, eos]),
  27. },
  28. {
  29. 'source': torch.LongTensor([w2, eos]),
  30. 'target': torch.LongTensor([w2, w1, eos]),
  31. },
  32. {
  33. 'source': torch.LongTensor([w2, eos]),
  34. 'target': torch.LongTensor([w2, eos]),
  35. },
  36. ]
  37. data_itr = test_utils.dummy_dataloader(data)
  38. # specify expected output probabilities
  39. args = argparse.Namespace()
  40. unk = 0.
  41. args.beam_probs = [
  42. # step 0:
  43. torch.FloatTensor([
  44. # eos w1 w2
  45. [0.0, unk, 0.6, 0.4], # sentence 1
  46. [0.0, unk, 0.4, 0.6], # sentence 2
  47. [0.0, unk, 0.7, 0.3], # sentence 3
  48. ]),
  49. # step 1:
  50. torch.FloatTensor([
  51. # eos w1 w2
  52. [0.0, unk, 0.2, 0.7], # sentence 1
  53. [0.0, unk, 0.8, 0.2], # sentence 2
  54. [0.7, unk, 0.1, 0.2], # sentence 3
  55. ]),
  56. # step 2:
  57. torch.FloatTensor([
  58. # eos w1 w2
  59. [0.10, unk, 0.50, 0.4], # sentence 1
  60. [0.15, unk, 0.15, 0.7], # sentence 2
  61. [0.00, unk, 0.00, 0.0], # sentence 3
  62. ]),
  63. # step 3:
  64. torch.FloatTensor([
  65. # eos w1 w2
  66. [0.9, unk, 0.05, 0.05], # sentence 1
  67. [0.0, unk, 0.00, 0.0], # sentence 2
  68. [0.0, unk, 0.00, 0.0], # sentence 3
  69. ]),
  70. ]
  71. expected_scores = [
  72. [0.6, 0.7, 0.5, 0.9], # sentence 1
  73. [0.6, 0.8, 0.15], # sentence 2
  74. [0.3, 0.7], # sentence 3
  75. ]
  76. task = test_utils.TestTranslationTask.setup_task(args, d, d)
  77. model = task.build_model(args)
  78. scorer = SequenceScorer([model], task.target_dictionary)
  79. for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
  80. self.assertHypoTokens(hypos[0], data[id]['target'])
  81. self.assertHypoScore(hypos[0], expected_scores[id])
  82. def assertHypoTokens(self, hypo, tokens):
  83. self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
  84. def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
  85. pos_scores = torch.FloatTensor(pos_probs).log()
  86. self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
  87. self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
  88. score = pos_scores.sum()
  89. if normalized:
  90. score /= pos_scores.numel()**lenpen
  91. self.assertLess(abs(score - hypo['score']), 1e-6)
  92. def assertAlmostEqual(self, t1, t2):
  93. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  94. self.assertLess((t1 - t2).abs().max(), 1e-4)
  95. def assertTensorEqual(self, t1, t2):
  96. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  97. self.assertEqual(t1.ne(t2).long().sum(), 0)
  98. if __name__ == '__main__':
  99. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...