Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

all_architectures_test.py 1.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
  1. import unittest
  2. from super_gradients.common.registry.registry import ARCHITECTURES
  3. from super_gradients.training.models.sg_module import SgModule
  4. from super_gradients.training.utils.utils import HpmStruct
  5. import torch
  6. class AllArchitecturesTest(unittest.TestCase):
  7. def setUp(self):
  8. # contains all arch_params needed for initialization of all architectures
  9. self.all_arch_params = HpmStruct(
  10. **{
  11. "num_classes": 10,
  12. "width_mult": 1,
  13. "threshold": 1,
  14. "sml_net": torch.nn.Identity(),
  15. "big_net": torch.nn.Identity(),
  16. "dropout": 0,
  17. "build_residual_branches": True,
  18. }
  19. )
  20. def test_architecture_is_sg_module(self):
  21. """
  22. Validate all models from all_architectures.py are SgModule
  23. """
  24. for arch_name in ARCHITECTURES:
  25. # skip custom constructors to keep all_arch_params as general as a possible
  26. if "custom" in arch_name.lower() or "nas" in arch_name.lower() or "kd" in arch_name.lower():
  27. continue
  28. self.assertTrue(isinstance(ARCHITECTURES[arch_name](arch_params=self.all_arch_params), SgModule))
  29. if __name__ == "__main__":
  30. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...