Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ppyoloe_unit_test.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
  1. import unittest
  2. import torch
  3. from super_gradients.training import models
  4. from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_e import PPYoloE_X, PPYoloE_S, PPYoloE_M, PPYoloE_L
  class TestPPYOLOE(unittest.TestCase):
      """Smoke tests for the PP-YOLOE detector variants (S, M, L, X).

      Every variant is built twice -- once through the ``models.get`` registry
      by name and once directly from its model class -- and each instance must
      return a non-None result for one random input tensor.
      """

      def _forward_is_not_none(self, model):
          """Run a single no-grad forward pass and assert the result is not None.

          The input is a random tensor of shape (1, 3, 640, 480). Only the
          presence of an output is checked; its structure is not inspected.
          """
          sample = torch.randn(1, 3, 640, 480)
          with torch.no_grad():
              output = model(sample)
          self.assertIsNotNone(output)

      def _test_ppyoloe_from_name(self, model_name, pretrained_weights):
          """Build ``model_name`` via ``models.get``, switch to eval mode, smoke-test it.

          ``num_classes`` is fixed to 80 only when no pretrained weights are
          requested; otherwise ``None`` is passed (presumably so the class count
          is taken from the checkpoint -- confirm against ``models.get``).
          """
          if pretrained_weights is None:
              n_classes = 80
          else:
              n_classes = None
          model = models.get(model_name, pretrained_weights=pretrained_weights, num_classes=n_classes)
          self._forward_is_not_none(model.eval())

      def _test_ppyoloe_from_cls(self, model_cls):
          """Instantiate ``model_cls`` with empty ``arch_params``, switch to eval mode, smoke-test it."""
          self._forward_is_not_none(model_cls(arch_params={}).eval())

      def _check_variant(self, model_name, model_cls, pretrained_weights):
          """Smoke-test one variant: first by registry name, then by class."""
          self._test_ppyoloe_from_name(model_name, pretrained_weights=pretrained_weights)
          self._test_ppyoloe_from_cls(model_cls)

      # S and M are built with the "coco" pretrained weights (which may need to
      # be fetched -- TODO confirm); L and X are built without pretrained weights.
      def test_ppyoloe_s(self):
          self._check_variant("ppyoloe_s", PPYoloE_S, pretrained_weights="coco")

      def test_ppyoloe_m(self):
          self._check_variant("ppyoloe_m", PPYoloE_M, pretrained_weights="coco")

      def test_ppyoloe_l(self):
          self._check_variant("ppyoloe_l", PPYoloE_L, pretrained_weights=None)

      def test_ppyoloe_x(self):
          self._check_variant("ppyoloe_x", PPYoloE_X, pretrained_weights=None)
  if __name__ == "__main__":
      # Allow running this test module directly, e.g. `python ppyoloe_unit_test.py`.
      unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...