Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

all_architectures_test.py 1.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
  1. import unittest
  2. from super_gradients.training.models.all_architectures import ARCHITECTURES
  3. from super_gradients.training.models.sg_module import SgModule
  4. from super_gradients.training.utils.utils import HpmStruct
  5. import torch
  6. class AllArchitecturesTest(unittest.TestCase):
  7. def setUp(self):
  8. # contains all arch_params needed for initialization of all architectures
  9. self.all_arch_params = HpmStruct(**{'num_classes': 10,
  10. 'width_mult': 1,
  11. 'threshold': 1,
  12. 'sml_net': torch.nn.Identity(),
  13. 'big_net': torch.nn.Identity(),
  14. 'dropout': 0})
  15. def test_architecture_is_sg_module(self):
  16. """
  17. Validate all models from all_architectures.py are SgModule
  18. """
  19. for arch_name in ARCHITECTURES:
  20. # skip custom constructors to keep all_arch_params as general as a possible
  21. if 'custom' in arch_name.lower() or 'nas' in arch_name.lower():
  22. continue
  23. self.assertTrue(isinstance(ARCHITECTURES[arch_name](arch_params=self.all_arch_params), SgModule))
  24. if __name__ == '__main__':
  25. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...