Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_reproducibility.py 3.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import contextlib
  8. from io import StringIO
  9. import json
  10. import os
  11. import tempfile
  12. import unittest
  13. from . import test_binaries
  14. class TestReproducibility(unittest.TestCase):
  15. def _test_reproducibility(self, name, extra_flags=None):
  16. if extra_flags is None:
  17. extra_flags = []
  18. with tempfile.TemporaryDirectory(name) as data_dir:
  19. with contextlib.redirect_stdout(StringIO()):
  20. test_binaries.create_dummy_data(data_dir)
  21. test_binaries.preprocess_translation_data(data_dir)
  22. # train epochs 1 and 2 together
  23. stdout = StringIO()
  24. with contextlib.redirect_stdout(stdout):
  25. test_binaries.train_translation_model(
  26. data_dir, 'fconv_iwslt_de_en', [
  27. '--dropout', '0.0',
  28. '--log-format', 'json',
  29. '--log-interval', '1',
  30. '--max-epoch', '3',
  31. ] + extra_flags,
  32. )
  33. stdout = stdout.getvalue()
  34. train_log, valid_log = map(json.loads, stdout.split('\n')[-4:-2])
  35. # train epoch 2, resuming from previous checkpoint 1
  36. os.rename(
  37. os.path.join(data_dir, 'checkpoint1.pt'),
  38. os.path.join(data_dir, 'checkpoint_last.pt'),
  39. )
  40. stdout = StringIO()
  41. with contextlib.redirect_stdout(stdout):
  42. test_binaries.train_translation_model(
  43. data_dir, 'fconv_iwslt_de_en', [
  44. '--dropout', '0.0',
  45. '--log-format', 'json',
  46. '--log-interval', '1',
  47. '--max-epoch', '3',
  48. ] + extra_flags,
  49. )
  50. stdout = stdout.getvalue()
  51. train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-4:-2])
  52. def cast(s):
  53. return round(float(s), 3)
  54. for k in ['loss', 'ppl', 'num_updates', 'gnorm']:
  55. self.assertEqual(cast(train_log[k]), cast(train_res_log[k]))
  56. for k in ['valid_loss', 'valid_ppl', 'num_updates', 'best']:
  57. self.assertEqual(cast(valid_log[k]), cast(valid_res_log[k]))
  58. def test_reproducibility(self):
  59. self._test_reproducibility('test_reproducibility')
  60. def test_reproducibility_fp16(self):
  61. self._test_reproducibility('test_reproducibility_fp16', [
  62. '--fp16',
  63. '--fp16-init-scale', '4096',
  64. ])
  65. def test_reproducibility_memory_efficient_fp16(self):
  66. self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [
  67. '--memory-efficient-fp16',
  68. '--fp16-init-scale', '4096',
  69. ])
  70. if __name__ == '__main__':
  71. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...