Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

repvgg_unit_test.py 5.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
  1. import unittest
  2. from super_gradients.training.models.all_architectures import ARCHITECTURES
  3. from super_gradients.training.models.classification_models.repvgg import RepVggA1
  4. from super_gradients.training.utils.utils import HpmStruct
  5. import torch
  6. import copy
  7. class BackboneBasedModel(torch.nn.Module):
  8. """
  9. Auxiliary model which will use repvgg as backbone
  10. """
  11. def __init__(self, backbone, backbone_output_channel, num_classes=1000):
  12. super(BackboneBasedModel, self).__init__()
  13. self.backbone = backbone
  14. self.conv = torch.nn.Conv2d(in_channels=backbone_output_channel, out_channels=backbone_output_channel, kernel_size=1,
  15. stride=1, padding=0)
  16. self.bn = torch.nn.BatchNorm2d(backbone_output_channel) # Adding a bn layer that should NOT be fused
  17. self.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)
  18. self.linear = torch.nn.Linear(backbone_output_channel, num_classes)
  19. def forward(self, x):
  20. x = self.backbone(x)
  21. x = self.conv(x)
  22. x = self.bn(x)
  23. x = self.avgpool(x)
  24. x = x.view(x.size(0), -1)
  25. return self.linear(x)
  26. def prep_model_for_conversion(self):
  27. if hasattr(self.backbone, 'prep_model_for_conversion'):
  28. self.backbone.prep_model_for_conversion()
  29. class TestRepVgg(unittest.TestCase):
  30. def setUp(self):
  31. # contains all arch_params needed for initialization of all architectures
  32. self.all_arch_params = HpmStruct(**{'num_classes': 10,
  33. 'width_mult': 1,
  34. 'build_residual_branches': True})
  35. self.backbone_arch_params = copy.deepcopy(self.all_arch_params)
  36. self.backbone_arch_params.override(backbone_mode=True)
  37. def test_deployment_architecture(self):
  38. """
  39. Validate all models that has a deployment mode are in fact different after deployment
  40. """
  41. image_size = 224
  42. in_channels = 3
  43. for arch_name in ARCHITECTURES:
  44. # skip custom constructors to keep all_arch_params as general as a possible
  45. if 'repvgg' not in arch_name or 'custom' in arch_name:
  46. continue
  47. model = ARCHITECTURES[arch_name](arch_params=self.all_arch_params)
  48. self.assertTrue(hasattr(model.stem, 'branch_3x3')) # check single layer for training mode
  49. self.assertTrue(model.build_residual_branches)
  50. training_mode_sd = model.state_dict()
  51. for module in training_mode_sd:
  52. self.assertFalse('reparam' in module) # deployment block included in training mode
  53. test_input = torch.ones((1, in_channels, image_size, image_size))
  54. model.eval()
  55. training_mode_output = model(test_input)
  56. model.prep_model_for_conversion()
  57. self.assertTrue(hasattr(model.stem, 'rbr_reparam')) # check single layer for training mode
  58. self.assertFalse(model.build_residual_branches)
  59. deployment_mode_sd = model.state_dict()
  60. for module in deployment_mode_sd:
  61. self.assertFalse('running_mean' in module) # BN were not fused
  62. self.assertFalse('branch' in module) # branches were not joined
  63. deployment_mode_output = model(test_input)
  64. # difference is of very low magnitude
  65. self.assertFalse(False in torch.isclose(training_mode_output, deployment_mode_output, atol=1e-5))
  66. def test_backbone_mode(self):
  67. """
  68. Validate repvgg models (A1) as backbone.
  69. """
  70. image_size = 224
  71. in_channels = 3
  72. test_input = torch.rand((1, in_channels, image_size, image_size))
  73. backbone_model = RepVggA1(self.backbone_arch_params)
  74. model = BackboneBasedModel(backbone_model, backbone_output_channel=1280,
  75. num_classes=self.backbone_arch_params.num_classes)
  76. backbone_model.eval()
  77. model.eval()
  78. backbone_training_mode_output = backbone_model(test_input)
  79. model_training_mode_output = model(test_input)
  80. # check that the linear head was dropped
  81. self.assertFalse(backbone_training_mode_output.shape[1] == self.backbone_arch_params.num_classes)
  82. training_mode_sd = model.state_dict()
  83. for module in training_mode_sd:
  84. self.assertFalse('reparam' in module) # deployment block included in training mode
  85. model.prep_model_for_conversion()
  86. deployment_mode_sd_list = list(model.state_dict().keys())
  87. self.assertTrue('bn.running_mean' in deployment_mode_sd_list) # Verify non backbone batch norm wasn't fused
  88. for module in deployment_mode_sd_list:
  89. self.assertFalse('running_mean' in module and module.startswith('backbone')) # BN were not fused
  90. self.assertFalse('branch' in module and module.startswith('backbone')) # branches were not joined
  91. model_deployment_mode_output = model(test_input)
  92. self.assertFalse(False in torch.isclose(model_deployment_mode_output, model_training_mode_output, atol=1e-5))
# Allow running the test module directly: `python repvgg_unit_test.py`.
if __name__ == '__main__':
    unittest.main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...