Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_token_block_dataset.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import unittest
  8. import torch
  9. from fairseq.data import TokenBlockDataset
  10. import tests.utils as test_utils
  11. class TestTokenBlockDataset(unittest.TestCase):
  12. def _build_dataset(self, data, **kwargs):
  13. sizes = [len(x) for x in data]
  14. underlying_ds = test_utils.TestDataset(data)
  15. return TokenBlockDataset(underlying_ds, sizes, **kwargs)
  16. def test_eos_break_mode(self):
  17. data = [
  18. torch.LongTensor([5, 4, 3, 2, 1]),
  19. torch.LongTensor([1]), # this should be filtered
  20. torch.LongTensor([8, 7, 6, 1]),
  21. ]
  22. ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
  23. self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
  24. self.assertEqual(ds[1].tolist(), [1])
  25. self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])
  26. data = [
  27. torch.LongTensor([5, 4, 3, 2, 1]),
  28. torch.LongTensor([8, 7, 6, 1]),
  29. torch.LongTensor([1]), # this should be filtered
  30. ]
  31. ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
  32. self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
  33. self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
  34. self.assertEqual(ds[2].tolist(), [1])
  35. def test_block_break_mode(self):
  36. data = [
  37. torch.LongTensor([5, 4, 3, 2, 1]),
  38. torch.LongTensor([8, 7, 6, 1]),
  39. torch.LongTensor([9, 1]),
  40. ]
  41. ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none')
  42. self.assertEqual(ds[0].tolist(), [5, 4, 3])
  43. self.assertEqual(ds[1].tolist(), [2, 1, 8])
  44. self.assertEqual(ds[2].tolist(), [7, 6, 1])
  45. self.assertEqual(ds[3].tolist(), [9, 1])
  46. def test_complete_break_mode(self):
  47. data = [
  48. torch.LongTensor([5, 4, 3, 2, 1]),
  49. torch.LongTensor([8, 7, 6, 1]),
  50. torch.LongTensor([9, 1]),
  51. ]
  52. ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete')
  53. self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
  54. self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])
  55. data = [
  56. torch.LongTensor([4, 3, 2, 1]),
  57. torch.LongTensor([5, 1]),
  58. torch.LongTensor([1]),
  59. torch.LongTensor([6, 1]),
  60. ]
  61. ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete')
  62. self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
  63. self.assertEqual(ds[1].tolist(), [5, 1, 1])
  64. self.assertEqual(ds[2].tolist(), [6, 1])
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...