Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_average_checkpoints.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import collections
  8. import os
  9. import tempfile
  10. import unittest
  11. import numpy as np
  12. import torch
  13. from scripts.average_checkpoints import average_checkpoints
  14. class TestAverageCheckpoints(unittest.TestCase):
  15. def test_average_checkpoints(self):
  16. params_0 = collections.OrderedDict(
  17. [
  18. ('a', torch.DoubleTensor([100.0])),
  19. ('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
  20. ('c', torch.IntTensor([7, 8, 9])),
  21. ]
  22. )
  23. params_1 = collections.OrderedDict(
  24. [
  25. ('a', torch.DoubleTensor([1.0])),
  26. ('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
  27. ('c', torch.IntTensor([2, 2, 2])),
  28. ]
  29. )
  30. params_avg = collections.OrderedDict(
  31. [
  32. ('a', torch.DoubleTensor([50.5])),
  33. ('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
  34. # We expect truncation for integer division
  35. ('c', torch.IntTensor([4, 5, 5])),
  36. ]
  37. )
  38. fd_0, path_0 = tempfile.mkstemp()
  39. fd_1, path_1 = tempfile.mkstemp()
  40. torch.save(collections.OrderedDict([('model', params_0)]), path_0)
  41. torch.save(collections.OrderedDict([('model', params_1)]), path_1)
  42. output = average_checkpoints([path_0, path_1])['model']
  43. os.close(fd_0)
  44. os.remove(path_0)
  45. os.close(fd_1)
  46. os.remove(path_1)
  47. for (k_expected, v_expected), (k_out, v_out) in zip(
  48. params_avg.items(), output.items()):
  49. self.assertEqual(
  50. k_expected, k_out, 'Key mismatch - expected {} but found {}. '
  51. '(Expected list of keys: {} vs actual list of keys: {})'.format(
  52. k_expected, k_out, params_avg.keys(), output.keys()
  53. )
  54. )
  55. np.testing.assert_allclose(
  56. v_expected.numpy(),
  57. v_out.numpy(),
  58. err_msg='Tensor value mismatch for key {}'.format(k_expected)
  59. )
  60. if __name__ == '__main__':
  61. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...