
#970 Update YoloNASQuickstart.md

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bugfix/SG-000_fix_readme_yolonas_snippets
from typing import Optional

import torch
import torch.nn as nn


class ChannelWiseKnowledgeDistillationLoss(nn.Module):
    """
    Implementation of the Channel-wise Knowledge Distillation loss.

    Paper: "Channel-wise Knowledge Distillation for Dense Prediction", https://arxiv.org/abs/2011.13256
    Official implementation: https://github.com/irfanICMLL/TorchDistiller/tree/main/SemSeg-distill
    """

    def __init__(self, normalization_mode: str = "channel_wise", temperature: float = 4.0, ignore_index: Optional[int] = None):
        """
        :param normalization_mode: `channel_wise` (default) applies softmax over the spatial dimensions, as in the
            original paper. For vanilla normalization, i.e. softmax over the channel dimension, set this value to
            `spatial_wise`.
        :param temperature: temperature relaxation value applied to the logits before normalization. The default of
            4.0 follows the original implementation.
        :param ignore_index: optional target value whose pixels are excluded from the loss. If set, `target` must be
            passed to `forward`.
        """
        super().__init__()
        self.T = temperature
        self.ignore_index = ignore_index
        # With an ignore_index the loss must stay element-wise so invalid pixels can be masked before reduction.
        self.kl_div = nn.KLDivLoss(reduction="sum" if ignore_index is None else "none")
        if normalization_mode not in ["channel_wise", "spatial_wise"]:
            raise ValueError(f"Unsupported normalization mode: {normalization_mode}")
        self.normalization_mode = normalization_mode

    def forward(self, student_preds: torch.Tensor, teacher_preds: torch.Tensor, target: Optional[torch.Tensor] = None):
        B, C, H, W = student_preds.size()
        # Set the normalization axis and the averaging scalar.
        norm_axis = -1 if self.normalization_mode == "channel_wise" else 1
        averaging_scalar = (B * C) if self.normalization_mode == "channel_wise" else (B * H * W)
        # Softmax normalization of the temperature-relaxed logits over the chosen axis.
        softmax_teacher = torch.softmax(teacher_preds.view(B, C, -1) / self.T, dim=norm_axis)
        log_softmax_student = torch.log_softmax(student_preds.view(B, C, -1) / self.T, dim=norm_axis)
        loss = self.kl_div(log_softmax_student, softmax_teacher)
        if self.ignore_index is not None:
            # Zero out the contribution of pixels whose target equals ignore_index, then reduce.
            valid_mask = target.view(B, -1).ne(self.ignore_index).unsqueeze(1).expand_as(loss)
            loss = (loss * valid_mask).sum()
        loss = loss * (self.T ** 2) / averaging_scalar
        return loss
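For reference, a minimal usage sketch of the class above (the tensor shapes, the 19-class setup, and the 255 ignore value are illustrative assumptions, not part of the PR):

import torch

# Hypothetical dense-prediction logits: batch of 2, 19 classes, 64x64 resolution.
student_preds = torch.randn(2, 19, 64, 64, requires_grad=True)
teacher_preds = torch.randn(2, 19, 64, 64)
target = torch.randint(0, 19, (2, 64, 64))
target[0, :8, :8] = 255  # mark some pixels to be excluded from the loss

criterion = ChannelWiseKnowledgeDistillationLoss(ignore_index=255)
loss = criterion(student_preds, teacher_preds, target)
loss.backward()  # gradients flow to the student logits only; the teacher has no grad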