Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_train.py 3.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import contextlib
  8. from io import StringIO
  9. import unittest
  10. from unittest.mock import MagicMock, patch
  11. import torch
  12. from fairseq import data
  13. import train
  14. def mock_trainer(epoch, num_updates, iterations_in_epoch):
  15. trainer = MagicMock()
  16. trainer.load_checkpoint.return_value = {
  17. 'train_iterator': {
  18. 'epoch': epoch,
  19. 'iterations_in_epoch': iterations_in_epoch,
  20. 'shuffle': False,
  21. },
  22. }
  23. trainer.get_num_updates.return_value = num_updates
  24. return trainer
  25. def mock_dict():
  26. d = MagicMock()
  27. d.pad.return_value = 1
  28. d.eos.return_value = 2
  29. d.unk.return_value = 3
  30. return d
  31. def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
  32. tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
  33. tokens_ds = data.TokenBlockDataset(
  34. tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
  35. )
  36. trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
  37. dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
  38. epoch_itr = data.EpochBatchIterator(
  39. dataset=dataset,
  40. collate_fn=dataset.collater,
  41. batch_sampler=[[i] for i in range(epoch_size)],
  42. )
  43. return trainer, epoch_itr
  44. class TestLoadCheckpoint(unittest.TestCase):
  45. def setUp(self):
  46. self.args_mock = MagicMock()
  47. self.args_mock.optimizer_overrides = '{}'
  48. self.patches = {
  49. 'os.makedirs': MagicMock(),
  50. 'os.path.join': MagicMock(),
  51. 'os.path.isfile': MagicMock(return_value=True),
  52. 'os.path.isabs': MagicMock(return_value=False),
  53. }
  54. self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
  55. [p.start() for p in self.applied_patches]
  56. def test_load_partial_checkpoint(self):
  57. with contextlib.redirect_stdout(StringIO()):
  58. trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
  59. train.load_checkpoint(self.args_mock, trainer, epoch_itr)
  60. self.assertEqual(epoch_itr.epoch, 2)
  61. self.assertEqual(epoch_itr.iterations_in_epoch, 50)
  62. itr = epoch_itr.next_epoch_itr(shuffle=False)
  63. self.assertEqual(epoch_itr.epoch, 2)
  64. self.assertEqual(epoch_itr.iterations_in_epoch, 50)
  65. self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
  66. self.assertEqual(epoch_itr.iterations_in_epoch, 51)
  67. def test_load_full_checkpoint(self):
  68. with contextlib.redirect_stdout(StringIO()):
  69. trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
  70. train.load_checkpoint(self.args_mock, trainer, epoch_itr)
  71. itr = epoch_itr.next_epoch_itr(shuffle=False)
  72. self.assertEqual(epoch_itr.epoch, 3)
  73. self.assertEqual(epoch_itr.iterations_in_epoch, 0)
  74. self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
  75. def test_load_no_checkpoint(self):
  76. with contextlib.redirect_stdout(StringIO()):
  77. trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
  78. self.patches['os.path.isfile'].return_value = False
  79. train.load_checkpoint(self.args_mock, trainer, epoch_itr)
  80. itr = epoch_itr.next_epoch_itr(shuffle=False)
  81. self.assertEqual(epoch_itr.epoch, 1)
  82. self.assertEqual(epoch_itr.iterations_in_epoch, 0)
  83. self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
  84. def tearDown(self):
  85. patch.stopall()
  86. if __name__ == '__main__':
  87. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...