Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

load_checkpoint_from_direct_path_test.py 3.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
  1. import shutil
  2. import tempfile
  3. import unittest
  4. import os
  5. from super_gradients.training import SgModel
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from super_gradients.training.sg_model.sg_model import StrictLoad
  class Net(nn.Module):
      """Small CNN classifier used as the model fixture for the checkpoint test.

      Two conv(3x3) -> ReLU -> max-pool(2x2) stages followed by three fully
      connected layers producing 10 logits. ``fc1`` expects ``16 * 3 * 3``
      input features, which is what the conv stack produces for 3-channel
      18x18 (or 19x19) images.
      """

      def __init__(self):
          super().__init__()
          self.conv1 = nn.Conv2d(3, 6, 3)
          self.pool = nn.MaxPool2d(2, 2)
          self.conv2 = nn.Conv2d(6, 16, 3)
          self.fc1 = nn.Linear(16 * 3 * 3, 120)
          self.fc2 = nn.Linear(120, 84)
          self.fc3 = nn.Linear(84, 10)

      def forward(self, x):
          """Return logits of shape ``(batch, 10)``.

          Args:
              x: image tensor of shape ``(N, 3, H, W)``, or an unbatched
                 ``(3, H, W)`` tensor, which is treated as a batch of one.

          Raises:
              RuntimeError: if the spatial size does not reduce to a 16x3x3
                 feature map (raised by ``fc1``).
          """
          if x.dim() == 3:
              # Unbatched input: add the batch axis so the output is (1, 10).
              x = x.unsqueeze(0)
          x = self.pool(F.relu(self.conv1(x)))
          x = self.pool(F.relu(self.conv2(x)))
          # Flatten everything except the batch axis. ``view(-1, 144)`` would
          # silently fold a wrong-sized feature map into extra batch rows;
          # flattening from dim 1 keeps the batch intact so a mismatch raises.
          x = torch.flatten(x, 1)
          x = F.relu(self.fc1(x))
          x = F.relu(self.fc2(x))
          return self.fc3(x)
  class LoadCheckpointFromDirectPathTest(unittest.TestCase):
      """Check that ``SgModel.build_model`` loads weights from an external checkpoint path."""

      @classmethod
      def setUpClass(cls):
          """Save a freshly initialised ``Net`` state dict into a private temp directory."""
          # Use mkdtemp rather than ``TemporaryDirectory(...).name``: discarding
          # the TemporaryDirectory object lets its finalizer delete the directory
          # (with a ResourceWarning) as soon as it is garbage collected.
          cls.temp_working_file_dir = tempfile.mkdtemp(prefix='load_checkpoint_test')
          cls.checkpoint_path = os.path.join(cls.temp_working_file_dir, 'load_checkpoint_test.pth')
          # Setup the model
          cls.original_torch_net = Net()
          # Save the model's checkpoint
          torch.save(cls.original_torch_net.state_dict(), cls.checkpoint_path)

      @classmethod
      def tearDownClass(cls):
          """Remove the temporary checkpoint directory created in ``setUpClass``."""
          if os.path.isdir(cls.temp_working_file_dir):
              shutil.rmtree(cls.temp_working_file_dir)

      def test_external_checkpoint_loaded_correctly(self):
          """A freshly built model must carry the checkpoint's weights after loading."""
          # Define Model
          new_torch_net = Net()
          # Make sure we initialized a model with different weights
          self.assertFalse(self.check_models_have_same_weights(new_torch_net, self.original_torch_net))
          # Build the SgModel and load the checkpoint.
          # NOTE(review): NO_KEY_MATCHING presumably loads parameters without
          # matching key names - confirm against SgModel.
          model = SgModel("load_checkpoint_test", model_checkpoints_location='local')
          model.build_model(new_torch_net, arch_params={'num_classes': 10},
                            external_checkpoint_path=self.checkpoint_path, load_checkpoint=True,
                            strict_load=StrictLoad.NO_KEY_MATCHING)
          # Assert the weights were loaded correctly
          self.assertTrue(self.check_models_have_same_weights(model.net, self.original_torch_net))

      def check_models_have_same_weights(self, model_1, model_2):
          """Return True if both models' state dicts hold equal tensors, entry by entry.

          Entries are paired by position, not by key, so differing key names
          (e.g. a wrapper prefix) do not by themselves make the models differ.
          State dicts with a different number of entries are reported as
          different. Tensors are compared on CPU; the models themselves are
          not moved.

          Args:
              model_1: first ``nn.Module`` to compare.
              model_2: second ``nn.Module`` to compare.

          Returns:
              True when every paired tensor is equal, False otherwise.
          """
          items_1 = list(model_1.state_dict().items())
          items_2 = list(model_2.state_dict().items())
          # zip() would silently truncate to the shorter state dict.
          if len(items_1) != len(items_2):
              return False
          models_differ = 0
          for (name_1, tensor_1), (name_2, tensor_2) in zip(items_1, items_2):
              if not torch.equal(tensor_1.cpu(), tensor_2.cpu()):
                  models_differ += 1
                  if name_1 == name_2:
                      print(f'Layer names match but layers have different weights for layers: {name_1}')
          return models_differ == 0
  if __name__ == '__main__':
      # Executing this file directly hands control to unittest, which runs
      # the TestCase classes defined in this module.
      unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...