1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
- # Copyright (c) 2017-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the LICENSE file in
- # the root directory of this source tree. An additional grant of patent rights
- # can be found in the PATENTS file in the same directory.
- import contextlib
- from io import StringIO
- import json
- import os
- import tempfile
- import unittest
- from . import test_binaries
class TestReproducibility(unittest.TestCase):
    """Check that resumed training logs the same stats as an uninterrupted run.

    Each test trains a small translation model twice on the same dummy data:
    once from scratch, and once after renaming ``checkpoint1.pt`` to
    ``checkpoint_last.pt``. Selected fields from the JSON log lines of both
    runs must agree after rounding to three decimal places.
    """

    # Fields compared between the two runs' parsed JSON log lines.
    _TRAIN_KEYS = ('loss', 'ppl', 'num_updates', 'gnorm')
    _VALID_KEYS = ('valid_loss', 'valid_ppl', 'num_updates', 'best')

    @staticmethod
    def _train_and_read_logs(data_dir, extra_flags):
        """Run one training invocation and return ``(train_log, valid_log)``.

        Training output is captured from stdout. The fourth- and
        third-from-last entries of the output split on newlines are parsed as
        JSON; presumably these are the final train and validation summary
        lines emitted by ``--log-format json`` (confirm against the trainer).

        Args:
            data_dir: directory holding the preprocessed data and checkpoints.
            extra_flags: list of additional command-line flags for training.
        """
        captured = StringIO()
        with contextlib.redirect_stdout(captured):
            test_binaries.train_translation_model(
                data_dir, 'fconv_iwslt_de_en', [
                    '--dropout', '0.0',
                    '--log-format', 'json',
                    '--log-interval', '1',
                    '--max-epoch', '3',
                ] + extra_flags,
            )
        output_lines = captured.getvalue().split('\n')
        train_log, valid_log = map(json.loads, output_lines[-4:-2])
        return train_log, valid_log

    @staticmethod
    def _rounded(value):
        """Convert a logged value to float and round it to 3 decimals."""
        return round(float(value), 3)

    def _test_reproducibility(self, name, extra_flags=None):
        """Train from scratch, then resume from checkpoint 1, and compare logs.

        Args:
            name: label used as the suffix of the temporary data directory.
            extra_flags: optional list of extra training flags applied to
                both runs (e.g. fp16 options).
        """
        if extra_flags is None:
            extra_flags = []

        # The first positional argument of TemporaryDirectory is ``suffix``;
        # spell it out so the directory naming is not mistaken for a prefix.
        with tempfile.TemporaryDirectory(suffix=name) as data_dir:
            # Data generation output is irrelevant to the comparison; mute it.
            with contextlib.redirect_stdout(StringIO()):
                test_binaries.create_dummy_data(data_dir)
                test_binaries.preprocess_translation_data(data_dir)

            # Baseline: one uninterrupted run up to --max-epoch 3.
            train_log, valid_log = self._train_and_read_logs(
                data_dir, extra_flags,
            )

            # Resumed run: expose the epoch-1 checkpoint under the "last"
            # checkpoint name and train again with identical flags.
            # NOTE(review): assumes the trainer restores checkpoint_last.pt
            # by default -- confirm against test_binaries.
            os.rename(
                os.path.join(data_dir, 'checkpoint1.pt'),
                os.path.join(data_dir, 'checkpoint_last.pt'),
            )
            train_res_log, valid_res_log = self._train_and_read_logs(
                data_dir, extra_flags,
            )

            for k in self._TRAIN_KEYS:
                self.assertEqual(
                    self._rounded(train_log[k]),
                    self._rounded(train_res_log[k]),
                    msg='train stat {!r} differs after resuming'.format(k),
                )
            for k in self._VALID_KEYS:
                self.assertEqual(
                    self._rounded(valid_log[k]),
                    self._rounded(valid_res_log[k]),
                    msg='valid stat {!r} differs after resuming'.format(k),
                )

    def test_reproducibility(self):
        """Reproducibility with the default training flags."""
        self._test_reproducibility('test_reproducibility')

    def test_reproducibility_fp16(self):
        """Reproducibility with ``--fp16`` and a fixed initial loss scale."""
        self._test_reproducibility('test_reproducibility_fp16', [
            '--fp16',
            '--fp16-init-scale', '4096',
        ])

    def test_reproducibility_memory_efficient_fp16(self):
        """Reproducibility with ``--memory-efficient-fp16`` and a fixed scale."""
        self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [
            '--memory-efficient-fp16',
            '--fp16-init-scale', '4096',
        ])
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|