Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

strictload_enum_test.py 6.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
  1. import os
  2. import shutil
  3. import tempfile
  4. import unittest
  5. from super_gradients.common.object_names import Models
  6. from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
  7. from super_gradients.training import Trainer
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from super_gradients.training import models
  12. from super_gradients.training.sg_trainer.sg_trainer import StrictLoad
  13. from super_gradients.training.utils import HpmStruct
  14. class Net(nn.Module):
  15. def __init__(self):
  16. super(Net, self).__init__()
  17. self.conv1 = nn.Conv2d(3, 6, 3)
  18. self.pool = nn.MaxPool2d(2, 2)
  19. self.conv2 = nn.Conv2d(6, 16, 3)
  20. self.fc1 = nn.Linear(16 * 3 * 3, 120)
  21. self.fc2 = nn.Linear(120, 84)
  22. self.fc3 = nn.Linear(84, 10)
  23. def forward(self, x):
  24. x = self.pool(F.relu(self.conv1(x)))
  25. x = self.pool(F.relu(self.conv2(x)))
  26. x = x.view(-1, 16 * 3 * 3)
  27. x = F.relu(self.fc1(x))
  28. x = F.relu(self.fc2(x))
  29. x = self.fc3(x)
  30. return x
  31. class StrictLoadEnumTest(unittest.TestCase):
  32. @classmethod
  33. def setUpClass(cls):
  34. cls.temp_working_file_dir = tempfile.TemporaryDirectory(prefix="strict_load_test").name
  35. if not os.path.isdir(cls.temp_working_file_dir):
  36. os.mkdir(cls.temp_working_file_dir)
  37. cls.experiment_name = "load_checkpoint_test"
  38. cls.checkpoint_diff_keys_name = "strict_load_test_diff_keys.pth"
  39. cls.checkpoint_diff_keys_path = cls.temp_working_file_dir + "/" + cls.checkpoint_diff_keys_name
  40. # Setup the model
  41. cls.original_torch_model = Net()
  42. # Save the model's state_dict checkpoint with different keys
  43. torch.save(cls.change_state_dict_keys(cls.original_torch_model.state_dict()), cls.checkpoint_diff_keys_path)
  44. # Save the model's state_dict checkpoint in Trainer format
  45. cls.trainer = Trainer("load_checkpoint_test")
  46. # This should be defined when calling `Trainer.train` but we don't do it here so we hardcode it
  47. cls.trainer.checkpoints_dir_path = cls.temp_working_file_dir
  48. cls.trainer.set_net(cls.original_torch_model)
  49. # FIXME: after uniting init and build_model we should remove this
  50. cls.trainer.sg_logger = BaseSGLogger(
  51. "project_name",
  52. "load_checkpoint_test",
  53. "local",
  54. resumed=False,
  55. training_params=HpmStruct(max_epochs=10),
  56. checkpoints_dir_path=cls.trainer.checkpoints_dir_path,
  57. monitor_system=False,
  58. )
  59. cls.trainer._save_checkpoint()
  60. @classmethod
  61. def tearDownClass(cls):
  62. if os.path.isdir(cls.temp_working_file_dir):
  63. shutil.rmtree(cls.temp_working_file_dir)
  64. @classmethod
  65. def change_state_dict_keys(self, state_dict):
  66. new_ckpt_dict = {}
  67. for i, (ckpt_key, ckpt_val) in enumerate(state_dict.items()):
  68. new_ckpt_dict[str(i)] = ckpt_val
  69. return new_ckpt_dict
  70. def check_models_have_same_weights(self, model_1, model_2):
  71. model_1, model_2 = model_1.to("cpu"), model_2.to("cpu")
  72. models_differ = 0
  73. for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
  74. if torch.equal(key_item_1[1], key_item_2[1]):
  75. pass
  76. else:
  77. models_differ += 1
  78. if key_item_1[0] == key_item_2[0]:
  79. print("Mismtach found at", key_item_1[0])
  80. else:
  81. raise Exception
  82. if models_differ == 0:
  83. return True
  84. else:
  85. return False
  86. def test_strict_load_on(self):
  87. # Define Model
  88. model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
  89. pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
  90. # Make sure we initialized a model with different weights
  91. assert not self.check_models_have_same_weights(model, pretrained_model)
  92. pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_on.pth")
  93. torch.save(pretrained_model.state_dict(), pretrained_sd_path)
  94. model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
  95. # Assert the weights were loaded correctly
  96. assert self.check_models_have_same_weights(model, pretrained_model)
  97. def test_strict_load_off(self):
  98. # Define Model
  99. model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
  100. pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
  101. # Make sure we initialized a model with different weights
  102. assert not self.check_models_have_same_weights(model, pretrained_model)
  103. pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_net_strict_load_off.pth")
  104. del pretrained_model.linear
  105. torch.save(pretrained_model.state_dict(), pretrained_sd_path)
  106. with self.assertRaises(RuntimeError):
  107. models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
  108. model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.OFF)
  109. del model.linear
  110. # Assert the weights were loaded correctly
  111. assert self.check_models_have_same_weights(model, pretrained_model)
  112. def test_strict_load_no_key_matching_sg_checkpoint(self):
  113. # Define Model
  114. model = models.get(Models.RESNET18, arch_params={"num_classes": 1000})
  115. pretrained_model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, pretrained_weights="imagenet")
  116. # Make sure we initialized a model with different weights
  117. assert not self.check_models_have_same_weights(model, pretrained_model)
  118. pretrained_sd_path = os.path.join(self.temp_working_file_dir, "pretrained_model_strict_load_soft.pth")
  119. torch.save(self.change_state_dict_keys(pretrained_model.state_dict()), pretrained_sd_path)
  120. with self.assertRaises(RuntimeError):
  121. models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.ON)
  122. model = models.get(Models.RESNET18, arch_params={"num_classes": 1000}, checkpoint_path=pretrained_sd_path, strict_load=StrictLoad.NO_KEY_MATCHING)
  123. # Assert the weights were loaded correctly
  124. assert self.check_models_have_same_weights(model, pretrained_model)
  125. if __name__ == "__main__":
  126. unittest.main()
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...