Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_backtranslation_dataset.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
  7. import unittest
  8. import torch
  9. from fairseq.data import (
  10. BacktranslationDataset,
  11. LanguagePairDataset,
  12. TransformEosDataset,
  13. )
  14. from fairseq.sequence_generator import SequenceGenerator
  15. import tests.utils as test_utils
  16. class TestBacktranslationDataset(unittest.TestCase):
  17. def setUp(self):
  18. self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
  19. test_utils.sequence_generator_setup()
  20. )
  21. dummy_src_samples = self.src_tokens
  22. self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
  23. self.cuda = torch.cuda.is_available()
  24. def _backtranslation_dataset_helper(
  25. self, remove_eos_from_input_src, remove_eos_from_output_src,
  26. ):
  27. tgt_dataset = LanguagePairDataset(
  28. src=self.tgt_dataset,
  29. src_sizes=self.tgt_dataset.sizes,
  30. src_dict=self.tgt_dict,
  31. tgt=None,
  32. tgt_sizes=None,
  33. tgt_dict=None,
  34. )
  35. generator = SequenceGenerator(
  36. models=[self.model],
  37. tgt_dict=self.tgt_dict,
  38. beam_size=2,
  39. unk_penalty=0,
  40. sampling=False,
  41. )
  42. if self.cuda:
  43. generator.cuda()
  44. backtranslation_dataset = BacktranslationDataset(
  45. tgt_dataset=TransformEosDataset(
  46. dataset=tgt_dataset,
  47. eos=self.tgt_dict.eos(),
  48. # remove eos from the input src
  49. remove_eos_from_src=remove_eos_from_input_src,
  50. ),
  51. backtranslation_fn=generator.generate,
  52. max_len_a=0,
  53. max_len_b=200,
  54. output_collater=TransformEosDataset(
  55. dataset=tgt_dataset,
  56. eos=self.tgt_dict.eos(),
  57. # if we remove eos from the input src, then we need to add it
  58. # back to the output tgt
  59. append_eos_to_tgt=remove_eos_from_input_src,
  60. remove_eos_from_src=remove_eos_from_output_src,
  61. ).collater,
  62. cuda=self.cuda,
  63. )
  64. dataloader = torch.utils.data.DataLoader(
  65. backtranslation_dataset,
  66. batch_size=2,
  67. collate_fn=backtranslation_dataset.collater,
  68. )
  69. backtranslation_batch_result = next(iter(dataloader))
  70. eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2
  71. # Note that we sort by src_lengths and add left padding, so actually
  72. # ids will look like: [1, 0]
  73. expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
  74. if remove_eos_from_output_src:
  75. expected_src = expected_src[:, :-1]
  76. expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
  77. generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
  78. tgt_tokens = backtranslation_batch_result["target"]
  79. self.assertTensorEqual(expected_src, generated_src)
  80. self.assertTensorEqual(expected_tgt, tgt_tokens)
  81. def test_backtranslation_dataset_no_eos_in_output_src(self):
  82. self._backtranslation_dataset_helper(
  83. remove_eos_from_input_src=False, remove_eos_from_output_src=True,
  84. )
  85. def test_backtranslation_dataset_with_eos_in_output_src(self):
  86. self._backtranslation_dataset_helper(
  87. remove_eos_from_input_src=False, remove_eos_from_output_src=False,
  88. )
  89. def test_backtranslation_dataset_no_eos_in_input_src(self):
  90. self._backtranslation_dataset_helper(
  91. remove_eos_from_input_src=True, remove_eos_from_output_src=False,
  92. )
  93. def assertTensorEqual(self, t1, t2):
  94. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  95. self.assertEqual(t1.ne(t2).long().sum(), 0)
  96. if __name__ == "__main__":
  97. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...