1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
|
- # Copyright (c) 2017-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the LICENSE file in
- # the root directory of this source tree. An additional grant of patent rights
- # can be found in the PATENTS file in the same directory.
- import collections
- import os
- import tempfile
- import unittest
- import numpy as np
- import torch
- from scripts.average_checkpoints import average_checkpoints
class TestAverageCheckpoints(unittest.TestCase):
    """Tests for ``average_checkpoints``."""

    def test_average_checkpoints(self):
        """Average two saved checkpoints and compare against known values.

        Two checkpoints, each holding a ``'model'`` mapping with a double,
        a float and an int tensor, are written to temporary files and
        passed to ``average_checkpoints``. The ``'model'`` entry of the
        result must have the same keys in the same order as the expected
        mapping, with element-wise matching values.
        """
        params_0 = collections.OrderedDict(
            [
                ('a', torch.DoubleTensor([100.0])),
                ('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
                ('c', torch.IntTensor([7, 8, 9])),
            ]
        )
        params_1 = collections.OrderedDict(
            [
                ('a', torch.DoubleTensor([1.0])),
                ('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
                ('c', torch.IntTensor([2, 2, 2])),
            ]
        )
        params_avg = collections.OrderedDict(
            [
                ('a', torch.DoubleTensor([50.5])),
                ('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
                # We expect truncation for integer division
                ('c', torch.IntTensor([4, 5, 5])),
            ]
        )

        checkpoint_paths = []
        try:
            for params in (params_0, params_1):
                fd, path = tempfile.mkstemp()
                # Record the path before anything else can fail so the
                # ``finally`` clause always knows which files to delete.
                checkpoint_paths.append(path)
                # Only the path is needed; torch.save reopens it by name.
                os.close(fd)
                torch.save(collections.OrderedDict([('model', params)]), path)

            output = average_checkpoints(checkpoint_paths)['model']
        finally:
            # Remove the temporary files even if saving or averaging raised.
            for path in checkpoint_paths:
                os.remove(path)

        # zip() below stops at the shorter input, so a missing or extra key
        # in the output would otherwise go unnoticed.
        expected_keys = list(params_avg.keys())
        actual_keys = list(output.keys())
        self.assertEqual(
            len(expected_keys), len(actual_keys),
            'Key count mismatch - expected keys {} but found {}'.format(
                expected_keys, actual_keys
            )
        )

        for (k_expected, v_expected), (k_out, v_out) in zip(
                params_avg.items(), output.items()):
            self.assertEqual(
                k_expected, k_out, 'Key mismatch - expected {} but found {}. '
                '(Expected list of keys: {} vs actual list of keys: {})'.format(
                    k_expected, k_out, params_avg.keys(), output.keys()
                )
            )
            # assert_allclose takes (actual, desired): the relative tolerance
            # is measured against the second argument and the failure report
            # labels the values accordingly.
            np.testing.assert_allclose(
                v_out.numpy(),
                v_expected.numpy(),
                err_msg='Tensor value mismatch for key {}'.format(k_expected)
            )
# Allow running this test module directly as a script; unittest.main()
# discovers and runs the test cases defined in this module.
if __name__ == '__main__':
    unittest.main()
|