Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_registry.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
  1. import unittest
  2. from typing import List
  3. from super_gradients.common.registry.registry import create_register_decorator
  4. from super_gradients.common.factories.base_factory import BaseFactory, UnknownTypeException
  5. class RegistryTest(unittest.TestCase):
  6. def setUp(self) -> None:
  7. # We do all the registration in `setUp` to avoid having registration ran on import
  8. _DUMMY_REGISTRY = {}
  9. register_class = create_register_decorator(registry=_DUMMY_REGISTRY)
  10. @register_class("good_object_name")
  11. class Class1:
  12. def __init__(self, values: List[float]):
  13. self.values = values
  14. @register_class(deprecated_name="deprecated_object_name")
  15. class Class2:
  16. def __init__(self, values: List[float]):
  17. self.values = values
  18. self.Class1 = Class1 # Save classes, not instances
  19. self.Class2 = Class2
  20. self.factory = BaseFactory(type_dict=_DUMMY_REGISTRY)
  21. def test_instantiate_from_name(self):
  22. instance = self.factory.get({"good_object_name": {"values": [1.0, 2.0]}})
  23. self.assertIsInstance(instance, self.Class1)
  24. def test_instantiate_from_classname_when_name_set(self):
  25. with self.assertRaises(UnknownTypeException):
  26. self.factory.get({"Class1": {"values": [1.0, 2.0]}})
  27. def test_instantiate_from_classname_when_no_name_set(self):
  28. instance = self.factory.get({"Class2": {"values": [1.0, 2.0]}})
  29. self.assertIsInstance(instance, self.Class2)
  30. def test_instantiate_from_deprecated_name(self):
  31. with self.assertWarns(DeprecationWarning):
  32. instance = self.factory.get({"deprecated_object_name": {"values": [1.0, 2.0]}})
  33. self.assertIsInstance(instance, self.Class2)
  34. if __name__ == "__main__":
  35. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...