#378 Feature/sg 281 add kd notebook

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-281-add_kd_notebook
import math
import warnings
from copy import deepcopy
from typing import Union

import torch
from torch import nn

from super_gradients.training import utils as core_utils
from super_gradients.training.models import SgModule
from super_gradients.training.models.kd_modules.kd_module import KDModule


def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
    # Copy attributes from b to a, with options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
        """
        Init the EMA
        :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by.
                      IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
                      AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
        :param decay: the maximum decay value. As the training process advances, the decay climbs towards this value,
                      at which point EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay)
        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay saturates to
                     its final value. With beta=15, saturation is reached at roughly 40% of the training process.
        :param exp_activation: whether to ramp the decay exponentially with training progress (True) or keep it constant (False)
        """
        # Create EMA
        self.ema = deepcopy(model)
        self.ema.eval()
        if exp_activation:
            self.decay_function = lambda x: decay * (1 - math.exp(-x * beta))  # decay exponential ramp (to help early epochs)
        else:
            self.decay_function = lambda x: decay  # always return the same decay factor

        """
        We hold a list of model attributes (not weights and biases) which we would like to include in each
        attribute update or exclude from each update. An SgModule declares these attributes using the
        get_include_attributes and get_exclude_attributes functions. For an nn.Module which is not an SgModule,
        all non-private (not starting with '_') attributes will be updated (and only them).
        """
        if isinstance(model.module, SgModule):
            self.include_attributes = model.module.get_include_attributes()
            self.exclude_attributes = model.module.get_exclude_attributes()
        else:
            warnings.warn("Warning: EMA should be used with SgModule instance. All attributes of the model will be "
                          "included in EMA")
            self.include_attributes = []
            self.exclude_attributes = []
        for p in self.ema.module.parameters():
            p.requires_grad_(False)

    def update(self, model, training_percent: float):
        """
        Update the state of the EMA model.
        :param model: current training model
        :param training_percent: the percentage of the training process [0,1], e.g. 0.4 means 40% of the training has passed
        """
        # Update EMA parameters
        with torch.no_grad():
            decay = self.decay_function(training_percent)
            for ema_v, model_v in zip(self.ema.module.state_dict().values(), model.state_dict().values()):
                if ema_v.dtype.is_floating_point:
                    ema_v.copy_(ema_v * decay + (1. - decay) * model_v.detach())

    def update_attr(self, model):
        """
        This function updates model attributes (not weights and biases) from the original model to the EMA model.
        Attributes of the original model, such as anchors and grids (of detection models), may be crucial to the
        model operation and need to be updated.
        If include_attributes and exclude_attributes lists were not defined, all non-private (not starting with '_')
        attributes will be updated (and only them).
        :param model: the source model
        """
        copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)


class KDModelEMA(ModelEMA):
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
        """
        Init the EMA
        :param kd_model: KDModule, the knowledge-distillation training model to construct the EMA model by.
                         IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
                         AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
        :param decay: the maximum decay value. As the training process advances, the decay climbs towards this value,
                      at which point EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay)
        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay saturates to
                     its final value. With beta=15, saturation is reached at roughly 40% of the training process.
        :param exp_activation: whether to ramp the decay exponentially with training progress (True) or keep it constant (False)
        """
        # Only work on the student (we don't want to update the teacher, nor keep a duplicate of it)
        super().__init__(model=core_utils.WrappedModel(kd_model.module.student),
                         decay=decay,
                         beta=beta,
                         exp_activation=exp_activation)

        # Overwrite the current ema attribute with a combination of the student EMA (the current self.ema)
        # and the already-instantiated teacher, to obtain the final KD EMA
        self.ema = core_utils.WrappedModel(KDModule(arch_params=kd_model.module.arch_params,
                                                    student=self.ema.module,
                                                    teacher=kd_model.module.teacher,
                                                    run_teacher_on_eval=kd_model.module.run_teacher_on_eval))