Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#396 Trainer constructor cleanup

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
  1. import inspect
  2. from functools import wraps
  3. from super_gradients.common.factories.base_factory import AbstractFactory
  4. def _assign_tuple(t: tuple, index: int, value):
  5. return tuple([x if i != index else value for i, x in enumerate(t)])
  6. def resolve_param(param_name: str, factory: AbstractFactory):
  7. """
  8. A decorator function which resolves a specific named parameter using a defined Factory
  9. usage:
  10. @resolve_param(my_param_name, MyFactory())
  11. def foo(self, a, my_param_name, b, c)
  12. ...
  13. this will use MyFactory to generate an object from the provided value of my_param_name
  14. """
  15. def inner(func):
  16. @wraps(func)
  17. def wrapper(*args, **kwargs):
  18. if param_name in kwargs:
  19. # handle kwargs
  20. kwargs[param_name] = factory.get(kwargs[param_name])
  21. else:
  22. # handle args
  23. func_args = inspect.getfullargspec(func).args
  24. if param_name in func_args:
  25. index = func_args.index(param_name)
  26. if index < len(args):
  27. new_value = factory.get(args[index])
  28. args = _assign_tuple(args, index, new_value)
  29. return func(*args, **kwargs)
  30. return wrapper
  31. return inner
Discard
Tip!

Press p or to see the previous file or, n or to see the next file