#378 Feature/sg 281 add kd notebook

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-281-add_kd_notebook
from super_gradients.training.models.sg_module import SgModule
from collections import namedtuple
import torch
from super_gradients.training.utils.utils import HpmStruct
from super_gradients.training.utils import get_param

KDOutput = namedtuple('KDOutput', 'student_output teacher_output')


class KDModule(SgModule):
    """
    KDModule - class implementing Knowledge Distillation logic as an SgModule

    attributes:
        student: SgModule - the student model
        teacher: torch.nn.Module - the teacher model
        run_teacher_on_eval: bool - whether to keep self.teacher in eval mode regardless of self.train(mode)
        arch_params: HpmStruct - architecture hyperparameters

    Additionally, by passing teacher_input_adapter (torch.nn.Module) one can modify the teacher net s.t.
    teacher = torch.nn.Sequential(teacher_input_adapter, teacher). This is useful when the teacher net expects a
    different input format from the student (for example, different normalization).
    """

    def __init__(self, arch_params: HpmStruct, student: SgModule, teacher: torch.nn.Module, run_teacher_on_eval=False):
        super(KDModule, self).__init__()
        self.arch_params = arch_params
        self.student = student
        self.teacher = teacher
        self.teacher_input_adapter = get_param(self.arch_params, "teacher_input_adapter")
        self.run_teacher_on_eval = run_teacher_on_eval
        self._freeze_teacher()

        # WHEN CREATING A MODULE, SELF.TRAIN() ISN'T CALLED, SO THE TEACHER MUST BE MOVED TO EVAL MODE EXPLICITLY
        if self.run_teacher_on_eval:
            self.teacher.eval()

    def _freeze_teacher(self):
        for p in self.teacher.parameters():
            p.requires_grad = False

        if self.teacher_input_adapter is not None:
            for p in self.teacher_input_adapter.parameters():
                p.requires_grad = False
            self.teacher_input_adapter.eval()

    def train(self, mode=True):
        self.student.train(mode)
        if not self.run_teacher_on_eval:
            self.teacher.train(mode)

    def eval(self):
        self.student.eval()
        self.teacher.eval()

    def forward(self, x):
        if self.teacher_input_adapter is not None:
            return KDOutput(student_output=self.student(x),
                            teacher_output=self.teacher(self.teacher_input_adapter(x)))
        else:
            return KDOutput(student_output=self.student(x),
                            teacher_output=self.teacher(x))

    def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
        return self.student.initialize_param_groups(lr, training_params)

    def update_param_groups(self, param_groups: list, lr: float, epoch: int, iter: int, training_params: HpmStruct,
                            total_batch: int) -> list:
        return self.student.update_param_groups(param_groups, lr, epoch, iter, training_params, total_batch)

    def replace_head(self, **kwargs):
        self.student.replace_head(**kwargs)
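
For context, here is a minimal usage sketch (not part of the PR): it wires a toy student and teacher into KDModule and computes a Hinton-style soft-label distillation loss from the returned KDOutput. The TinyStudent and teacher definitions, the Identity adapter, and the temperature T=4.0 are illustrative assumptions, not code from super-gradients.

import torch
import torch.nn.functional as F

# Toy 10-class models, purely illustrative; the student must be an SgModule
# per KDModule's signature.
class TinyStudent(SgModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(),
            torch.nn.Linear(8, 10))

    def forward(self, x):
        return self.net(x)

# The teacher can be any torch.nn.Module.
teacher = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1), torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(),
    torch.nn.Linear(16, 10))

kd_net = KDModule(
    # Identity() stands in for a real adapter, e.g. a renormalization module
    # when the teacher expects differently normalized inputs.
    arch_params=HpmStruct(teacher_input_adapter=torch.nn.Identity()),
    student=TinyStudent(),
    teacher=teacher,
    run_teacher_on_eval=True)

out = kd_net(torch.randn(4, 3, 32, 32))  # KDOutput(student_output=..., teacher_output=...)

# Hinton-style KD loss: KL divergence between temperature-softened logits.
T = 4.0
kd_loss = F.kl_div(F.log_softmax(out.student_output / T, dim=1),
                   F.softmax(out.teacher_output / T, dim=1),
                   reduction="batchmean") * T * T

Note that run_teacher_on_eval=True keeps the teacher in eval mode (e.g. frozen BatchNorm statistics) even while the student trains, which matches the explicit teacher.eval() call in KDModule's constructor.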