Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

load_checkpoint_test.py 1.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
  1. import unittest
  2. import torch.nn.init
  3. from torch import nn
  4. from super_gradients.training.utils.checkpoint_utils import transfer_weights
  5. class LoadCheckpointTest(unittest.TestCase):
  6. def test_transfer_weights(self):
  7. class Foo(nn.Module):
  8. def __init__(self):
  9. super().__init__()
  10. self.fc1 = nn.Linear(10, 10)
  11. self.fc2 = nn.Linear(10, 10)
  12. torch.nn.init.zeros_(self.fc1.weight)
  13. torch.nn.init.zeros_(self.fc2.weight)
  14. class Bar(nn.Module):
  15. def __init__(self):
  16. super().__init__()
  17. self.fc1 = nn.Linear(10, 11)
  18. self.fc2 = nn.Linear(10, 10)
  19. torch.nn.init.ones_(self.fc1.weight)
  20. torch.nn.init.ones_(self.fc2.weight)
  21. foo = Foo()
  22. bar = Bar()
  23. self.assertFalse((foo.fc2.weight == bar.fc2.weight).all())
  24. transfer_weights(foo, bar.state_dict())
  25. self.assertTrue((foo.fc2.weight == bar.fc2.weight).all())
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...