Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#378 Feature/sg 281 add kd notebook

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-281-add_kd_notebook
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
  1. from typing import Union
  2. from torch import nn
  3. from super_gradients.training.utils.utils import HpmStruct
  4. class SgModule(nn.Module):
  5. def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
  6. """
  7. :return: list of dictionaries containing the key 'named_params' with a list of named params
  8. """
  9. return [{"named_params": self.named_parameters()}]
  10. def update_param_groups(
  11. self, param_groups: list, lr: float, epoch: int, iter: int, training_params: HpmStruct, total_batch: int
  12. ) -> list:
  13. """
  14. :param param_groups: list of dictionaries containing the params
  15. :return: list of dictionaries containing the params
  16. """
  17. for param_group in param_groups:
  18. param_group["lr"] = lr
  19. return param_groups
  20. def get_include_attributes(self) -> list:
  21. """
  22. This function is used by the EMA. When updating the EMA model, some attributes of the main model (used in training)
  23. are updated to the EMA model along with the model weights.
  24. By default, all attributes are updated except for private attributes (starting with '_')
  25. You can either set include_attributes or exclude_attributes. By returning a non empty list from this function,
  26. you override the default behaviour and only attributes named in this list will be updated.
  27. Note: This will also override the get_exclude_attributes list.
  28. :return: list of attributes to update from main model to EMA model
  29. """
  30. return []
  31. def get_exclude_attributes(self) -> list:
  32. """
  33. This function is used by the EMA. When updating the EMA model, some attributes of the main model (used in training)
  34. are updated to the EMA model along with the model weights.
  35. By default, all attributes are updated except for private attributes (starting with '_')
  36. You can either set include_attributes or exclude_attributes. By returning a non empty list from this function,
  37. you override the default behaviour and attributes named in this list will also be excluded from update.
  38. Note: if get_include_attributes is not empty, it will override this list.
  39. :return: list of attributes to not update from main model to EMA mode
  40. """
  41. return []
  42. def prep_model_for_conversion(self, input_size: Union[tuple, list] = None, **kwargs):
  43. """
  44. Prepare the model to be converted to ONNX or other frameworks.
  45. Typically, this function will freeze the size of layers which is otherwise flexible, replace some modules
  46. with convertible substitutes and remove all auxiliary or training related parts.
  47. :param input_size: [H,W]
  48. """
  49. def replace_head(self, **kwargs):
  50. """
  51. Replace final layer for pretrained models. Since this varies between architectures, we leave it to the inheriting
  52. class to implement.
  53. """
  54. raise NotImplementedError
Discard
Tip!

Press p or to see the previous file or, n or to see the next file