Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#680 Optimizer readme

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-605-add_optimizers_tuto
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
  1. from super_gradients.modules import Residual, SkipConnection, BackboneInternalSkipConnection, HeadInternalSkipConnection, CrossModelSkipConnection
  2. try:
  3. from pytorch_quantization import nn as quant_nn
  4. from super_gradients.training.utils.quantization.core import SGQuantMixin
  5. from super_gradients.training.utils.quantization.selective_quantization_utils import register_quantized_module
  6. _imported_pytorch_quantization_failure = None
  7. except (ImportError, NameError, ModuleNotFoundError) as import_err:
  8. _imported_pytorch_quantization_failure = import_err
  9. @register_quantized_module(float_source=Residual)
  10. class QuantResidual(SGQuantMixin):
  11. """
  12. This is a placeholder module used by the quantization engine only.
  13. The module is to be used as a quantized substitute to a residual skip connection within a single block.
  14. """
  15. if _imported_pytorch_quantization_failure is not None:
  16. raise _imported_pytorch_quantization_failure
  17. @classmethod
  18. def from_float(cls, float_instance: Residual, **kwargs):
  19. return quant_nn.TensorQuantizer(kwargs.get("quant_desc_input"))
  20. @register_quantized_module(float_source=SkipConnection)
  21. class QuantSkipConnection(SGQuantMixin):
  22. """
  23. This is a placeholder module used by the quantization engine only.
  24. The module is to be used as a quantized substitute to a skip connection between blocks.
  25. """
  26. if _imported_pytorch_quantization_failure is not None:
  27. raise _imported_pytorch_quantization_failure
  28. @classmethod
  29. def from_float(cls, float_instance: SkipConnection, **kwargs):
  30. return quant_nn.TensorQuantizer(kwargs.get("quant_desc_input"))
  31. @register_quantized_module(float_source=BackboneInternalSkipConnection)
  32. class QuantBackboneInternalSkipConnection(QuantSkipConnection):
  33. """
  34. This is a placeholder module used by the quantization engine only.
  35. The module is to be used as a quantized substitute to a skip connection between blocks inside the backbone.
  36. """
  37. @register_quantized_module(float_source=HeadInternalSkipConnection)
  38. class QuantHeadInternalSkipConnection(QuantSkipConnection):
  39. """
  40. This is a placeholder module used by the quantization engine only.
  41. The module is to be used as a quantized substitute to a skip connection between blocks inside the head.
  42. """
  43. @register_quantized_module(float_source=CrossModelSkipConnection)
  44. class QuantCrossModelSkipConnection(QuantSkipConnection):
  45. """
  46. This is a placeholder module used by the quantization engine only.
  47. The module is to be used as a quantized substitute to a skip connection between backbone and the head.
  48. """
Discard
Tip!

Press p or to see the previous file or, n or to see the next file