kd_losses.py

from torch.nn.modules.loss import _Loss, KLDivLoss
import torch

from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.losses_factory import LossesFactory
from super_gradients.common.object_names import Losses
from super_gradients.common.registry.registry import register_loss


class KDklDivLoss(KLDivLoss):
    """KL divergence wrapper for knowledge distillation."""

    def __init__(self):
        super(KDklDivLoss, self).__init__(reduction="batchmean")

    def forward(self, student_output, teacher_output):
        # KLDivLoss expects log-probabilities for the input and probabilities for the target.
        return super(KDklDivLoss, self).forward(torch.log_softmax(student_output, dim=1), torch.softmax(teacher_output, dim=1))


@register_loss(name=Losses.KD_LOSS, deprecated_name="kd_loss")
class KDLogitsLoss(_Loss):
    """Knowledge distillation loss, wraps the task loss and distillation loss."""

    @resolve_param("task_loss_fn", LossesFactory())
    def __init__(self, task_loss_fn: _Loss, distillation_loss_fn: _Loss = KDklDivLoss(), distillation_loss_coeff: float = 0.5):
        """
        :param task_loss_fn:            task loss, e.g. CrossEntropyLoss
        :param distillation_loss_fn:    distillation loss, e.g. KLDivLoss
        :param distillation_loss_coeff: weight of the distillation loss in the total loss; the task loss is weighted by (1 - distillation_loss_coeff)
        """
        super(KDLogitsLoss, self).__init__()
        self.task_loss_fn = task_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.distillation_loss_coeff = distillation_loss_coeff

    @property
    def component_names(self):
        """
        Component names for logging during training.
        These correspond to the 2nd item in the tuple returned by self.forward(...).
        See super_gradients.Trainer.train() docs for more info.
        """
        return ["Loss", "Task Loss", "Distillation Loss"]

    def forward(self, kd_module_output, target):
        task_loss = self.task_loss_fn(kd_module_output.student_output, target)
        if isinstance(task_loss, tuple):  # some loss functions return (loss, log_items)
            task_loss = task_loss[0]
        distillation_loss = self.distillation_loss_fn(kd_module_output.student_output, kd_module_output.teacher_output)

        loss = task_loss * (1 - self.distillation_loss_coeff) + distillation_loss * self.distillation_loss_coeff

        return loss, torch.cat((loss.unsqueeze(0), task_loss.unsqueeze(0), distillation_loss.unsqueeze(0))).detach()
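
For reference, a minimal usage sketch of KDLogitsLoss (not part of the file above). It assumes super_gradients is installed and that this file is importable as kd_losses; the SimpleNamespace below is only a stand-in for the KD module's paired output object, which needs to expose student_output and teacher_output attributes.

from types import SimpleNamespace

import torch
import torch.nn as nn

from kd_losses import KDLogitsLoss  # assumed import path for this file

# Weight the hard-label task loss and the soft-label distillation loss equally.
criterion = KDLogitsLoss(task_loss_fn=nn.CrossEntropyLoss(), distillation_loss_coeff=0.5)

batch_size, num_classes = 8, 10
kd_output = SimpleNamespace(  # stand-in for the KD module's output object
    student_output=torch.randn(batch_size, num_classes, requires_grad=True),
    teacher_output=torch.randn(batch_size, num_classes),
)
target = torch.randint(0, num_classes, (batch_size,))

loss, log_items = criterion(kd_output, target)
loss.backward()
# log_items holds [Loss, Task Loss, Distillation Loss], matching component_names.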