Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#669 Hotfix/sg 645 regression tests essential fixes

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-645_limit_tests_forward_passes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
  1. import unittest
  2. from super_gradients.training.models.all_architectures import ARCHITECTURES
  3. from super_gradients.training.models.classification_models.repvgg import RepVggA1
  4. from super_gradients.training.utils.utils import HpmStruct
  5. import torch
  6. import copy
  7. class BackboneBasedModel(torch.nn.Module):
  8. """
  9. Auxiliary model which will use repvgg as backbone
  10. """
  11. def __init__(self, backbone, backbone_output_channel, num_classes=1000):
  12. super(BackboneBasedModel, self).__init__()
  13. self.backbone = backbone
  14. self.conv = torch.nn.Conv2d(in_channels=backbone_output_channel, out_channels=backbone_output_channel, kernel_size=1, stride=1, padding=0)
  15. self.bn = torch.nn.BatchNorm2d(backbone_output_channel) # Adding a bn layer that should NOT be fused
  16. self.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)
  17. self.linear = torch.nn.Linear(backbone_output_channel, num_classes)
  18. def forward(self, x):
  19. x = self.backbone(x)
  20. x = self.conv(x)
  21. x = self.bn(x)
  22. x = self.avgpool(x)
  23. x = x.view(x.size(0), -1)
  24. return self.linear(x)
  25. def prep_model_for_conversion(self):
  26. if hasattr(self.backbone, "prep_model_for_conversion"):
  27. self.backbone.prep_model_for_conversion()
  28. class TestRepVgg(unittest.TestCase):
  29. def setUp(self):
  30. # contains all arch_params needed for initialization of all architectures
  31. self.all_arch_params = HpmStruct(**{"num_classes": 10, "width_mult": 1, "build_residual_branches": True})
  32. self.backbone_arch_params = copy.deepcopy(self.all_arch_params)
  33. self.backbone_arch_params.override(backbone_mode=True)
  34. def test_deployment_architecture(self):
  35. """
  36. Validate all models that has a deployment mode are in fact different after deployment
  37. """
  38. image_size = 224
  39. in_channels = 3
  40. for arch_name in ARCHITECTURES:
  41. # skip custom constructors to keep all_arch_params as general as a possible
  42. if "repvgg" not in arch_name or "custom" in arch_name:
  43. continue
  44. model = ARCHITECTURES[arch_name](arch_params=self.all_arch_params)
  45. self.assertTrue(hasattr(model.stem, "branch_3x3")) # check single layer for training mode
  46. self.assertTrue(model.build_residual_branches)
  47. training_mode_sd = model.state_dict()
  48. for module in training_mode_sd:
  49. self.assertFalse("reparam" in module) # deployment block included in training mode
  50. test_input = torch.ones((1, in_channels, image_size, image_size))
  51. model.eval()
  52. training_mode_output = model(test_input)
  53. model.prep_model_for_conversion()
  54. self.assertTrue(hasattr(model.stem, "rbr_reparam")) # check single layer for training mode
  55. self.assertFalse(model.build_residual_branches)
  56. deployment_mode_sd = model.state_dict()
  57. for module in deployment_mode_sd:
  58. self.assertFalse("running_mean" in module) # BN were not fused
  59. self.assertFalse("branch" in module) # branches were not joined
  60. deployment_mode_output = model(test_input)
  61. # difference is of very low magnitude
  62. self.assertFalse(False in torch.isclose(training_mode_output, deployment_mode_output, atol=1e-4))
  63. def test_backbone_mode(self):
  64. """
  65. Validate repvgg models (A1) as backbone.
  66. """
  67. image_size = 224
  68. in_channels = 3
  69. test_input = torch.rand((1, in_channels, image_size, image_size))
  70. backbone_model = RepVggA1(self.backbone_arch_params)
  71. model = BackboneBasedModel(backbone_model, backbone_output_channel=1280, num_classes=self.backbone_arch_params.num_classes)
  72. backbone_model.eval()
  73. model.eval()
  74. backbone_training_mode_output = backbone_model(test_input)
  75. model_training_mode_output = model(test_input)
  76. # check that the linear head was dropped
  77. self.assertFalse(backbone_training_mode_output.shape[1] == self.backbone_arch_params.num_classes)
  78. training_mode_sd = model.state_dict()
  79. for module in training_mode_sd:
  80. self.assertFalse("reparam" in module) # deployment block included in training mode
  81. model.prep_model_for_conversion()
  82. deployment_mode_sd_list = list(model.state_dict().keys())
  83. self.assertTrue("bn.running_mean" in deployment_mode_sd_list) # Verify non backbone batch norm wasn't fused
  84. for module in deployment_mode_sd_list:
  85. self.assertFalse("running_mean" in module and module.startswith("backbone")) # BN were not fused
  86. self.assertFalse("branch" in module and module.startswith("backbone")) # branches were not joined
  87. model_deployment_mode_output = model(test_input)
  88. self.assertFalse(False in torch.isclose(model_deployment_mode_output, model_training_mode_output, atol=1e-5))
  89. if __name__ == "__main__":
  90. unittest.main()
Discard
Tip!

Press p to see the previous file, or n to see the next file