#869 Add DagsHub Logger to Super Gradients

Merged
Ghost merged 1 commit into Deci-AI:master from timho102003:dagshub_logger
from typing import Union

import torch
import torch.nn as nn

from super_gradients.training.models.kd_modules.kd_module import KDOutput


class SegKDLoss(nn.Module):
    """
    Wrapper loss for semantic segmentation KD.
    This loss includes two loss components, `ce_loss` i.e. CrossEntropyLoss, and `kd_loss` i.e.
    `ChannelWiseKnowledgeDistillationLoss`.
    """

    def __init__(self, kd_loss: nn.Module, ce_loss: nn.Module, weights: Union[tuple, list], kd_loss_weights: Union[tuple, list]):
        """
        :param kd_loss: knowledge distillation criterion, such as ChannelWiseKnowledgeDistillationLoss.
            This loss should expect as input a triplet of the predictions from the model with shape [B, C, H, W],
            the teacher model predictions with shape [B, C, H, W] and the target labels with shape [B, H, W].
        :param ce_loss: classification criterion, such as CE, OHEM, MaskAttention, SL1, etc.
            This loss should expect as input the predictions from the model with shape [B, C, H, W], and the target
            labels with shape [B, H, W].
        :param weights: lambda weights to apply to each prediction map head.
        :param kd_loss_weights: lambda weights to apply to each criterion. 2 values are expected, as follows:
            [ce_loss_weight, kd_loss_weight].
        """
        super().__init__()
        self.kd_loss_weights = kd_loss_weights
        self.weights = weights
        self.kd_loss = kd_loss
        self.ce_loss = ce_loss

        self._validate_arguments()

    def _validate_arguments(self):
        # Check num of loss weights
        if len(self.kd_loss_weights) != 2:
            raise ValueError(f"kd_loss_weights is expected to be an iterable with size 2, found: {len(self.kd_loss_weights)}")

    def forward(self, preds: KDOutput, target: torch.Tensor):
        if not isinstance(preds, KDOutput):
            raise RuntimeError(
                "Predictions argument for `SegKDLoss` forward method is expected to be a `KDOutput` to"
                " include the predictions from both the student and the teacher models."
            )
        teacher_preds = preds.teacher_output
        student_preds = preds.student_output

        # Normalize single-tensor outputs to tuples so the per-head loop below works uniformly.
        if isinstance(teacher_preds, torch.Tensor):
            teacher_preds = (teacher_preds,)
        if isinstance(student_preds, torch.Tensor):
            student_preds = (student_preds,)

        losses = []
        total_loss = 0
        # Main and auxiliary feature map losses
        for i in range(len(self.weights)):
            ce_loss = self.ce_loss(student_preds[i], target)
            cwd_loss = self.kd_loss(student_preds[i], teacher_preds[i], target)

            loss = self.kd_loss_weights[0] * ce_loss + self.kd_loss_weights[1] * cwd_loss
            total_loss += self.weights[i] * loss
            losses += [ce_loss, cwd_loss]

        losses.append(total_loss)
        return total_loss, torch.stack(losses, dim=0).detach()

    @property
    def component_names(self):
        """
        Component names for logging during training.
        These correspond to the 2nd item in the tuple returned in self.forward(...).
        See super_gradients.Trainer.train() docs for more info.
        """
        component_names = []
        for i in range(len(self.weights)):
            component_names += [f"Head-{i}_CE_Loss", f"Head-{i}_KD_Loss"]
        component_names.append("Total_Loss")
        return component_names
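
For context, here is a minimal usage sketch of SegKDLoss; it is not part of this diff. It assumes the class above is in scope, that KDOutput accepts student_output and teacher_output as constructor keywords (matching the attribute access in forward()), and it substitutes a hypothetical ToyKDLoss for ChannelWiseKnowledgeDistillationLoss, whose constructor is not shown here.

import torch
import torch.nn as nn
from super_gradients.training.models.kd_modules.kd_module import KDOutput


class ToyKDLoss(nn.Module):
    # Hypothetical stand-in for ChannelWiseKnowledgeDistillationLoss.
    # Matches the triplet signature documented in SegKDLoss.__init__:
    # (student [B, C, H, W], teacher [B, C, H, W], target [B, H, W]).
    def forward(self, student_preds, teacher_preds, target):
        return nn.functional.mse_loss(student_preds, teacher_preds)


criterion = SegKDLoss(
    kd_loss=ToyKDLoss(),
    ce_loss=nn.CrossEntropyLoss(),
    weights=[1.0],               # one lambda weight per prediction head
    kd_loss_weights=[1.0, 3.0],  # [ce_loss_weight, kd_loss_weight]
)

B, C, H, W = 2, 19, 64, 64
student = torch.randn(B, C, H, W)
teacher = torch.randn(B, C, H, W)
target = torch.randint(0, C, (B, H, W))  # integer class labels per pixel

# Assumed keyword construction, based on the fields read in forward().
preds = KDOutput(student_output=student, teacher_output=teacher)
total, components = criterion(preds, target)
# `components` lines up index-for-index with criterion.component_names:
# ["Head-0_CE_Loss", "Head-0_KD_Loss", "Total_Loss"]

The detached component tensor is what the Trainer can log per head, which is why component_names mirrors the order in which forward() appends losses.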