Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#970 Update YoloNASQuickstart.md

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bugfix/SG-000_fix_readme_yolonas_snippets
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
  1. from typing import Tuple
  2. import torch
  3. from torch import Tensor, nn
  4. from super_gradients.common.object_names import Losses
  5. from super_gradients.common.registry.registry import register_loss
  6. @register_loss(Losses.DEKR_LOSS)
  7. class DEKRLoss(nn.Module):
  8. """
  9. Implementation of the loss function from the "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression"
  10. paper (https://arxiv.org/abs/2104.02300)
  11. This loss should be used in conjunction with DEKRTargetsGenerator.
  12. """
  13. def __init__(self, heatmap_loss_factor: float = 1.0, offset_loss_factor: float = 0.1, heatmap_loss: str = "mse"):
  14. """
  15. Instantiate the DEKR loss function. It is two-component loss function, consisting of a heatmap (MSE) loss and an offset (Smooth L1) losses.
  16. The total loss is the sum of the two individual losses, weighted by the corresponding factors.
  17. :param heatmap_loss_factor: Weighting factor for heatmap loss
  18. :param offset_loss_factor: Weighting factor for offset loss
  19. :param heatmap_loss: Type of heatmap loss to use. Can be "mse" (Used in DEKR paper) or "qfl" (Quality Focal Loss).
  20. We use QFL in our recipe as it produces better results.
  21. """
  22. super().__init__()
  23. self.heatmap_loss_factor = float(heatmap_loss_factor)
  24. self.offset_loss_factor = float(offset_loss_factor)
  25. self.heatmap_loss = {"mse": self.heatmap_mse_loss, "qfl": self.heatmap_qfl_loss}[heatmap_loss]
  26. @property
  27. def component_names(self):
  28. """
  29. Names of individual loss components for logging during training.
  30. """
  31. return ["heatmap", "offset", "total"]
  32. def forward(self, predictions: Tuple[Tensor, Tensor], targets: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
  33. """
  34. :param predictions: Tuple of (heatmap, offset) predictions.
  35. heatmap is of shape (B, NumJoints + 1, H, W)
  36. offset is of shape (B, NumJoints * 2, H, W)
  37. :param targets: Tuple of (heatmap, mask, offset, offset_weight).
  38. heatmap is of shape (B, NumJoints + 1, H, W)
  39. mask is of shape (B, NumJoints + 1, H, W)
  40. offset is of shape (B, NumJoints * 2, H, W)
  41. offset_weight is of shape (B, NumJoints * 2, H, W)
  42. :return: Tuple of (loss, loss_components)
  43. loss is a scalar tensor with the total loss
  44. loss_components is a tensor of shape (3,) containing the individual loss components for logging (detached from the graph)
  45. """
  46. pred_heatmap, pred_offset = predictions
  47. gt_heatmap, mask, gt_offset, offset_weight = targets
  48. heatmap_loss = self.heatmap_loss(pred_heatmap, gt_heatmap, mask) * self.heatmap_loss_factor
  49. offset_loss = self.offset_loss(pred_offset, gt_offset, offset_weight) * self.offset_loss_factor
  50. loss = heatmap_loss + offset_loss
  51. components = torch.cat(
  52. (
  53. heatmap_loss.unsqueeze(0),
  54. offset_loss.unsqueeze(0),
  55. loss.unsqueeze(0),
  56. )
  57. ).detach()
  58. return loss, components
  59. def heatmap_mse_loss(self, pred_heatmap, true_heatmap, mask):
  60. loss = torch.nn.functional.mse_loss(pred_heatmap, true_heatmap, reduction="none") * mask
  61. loss = loss.mean()
  62. return loss
  63. def heatmap_qfl_loss(self, pred_heatmap, true_heatmap, mask):
  64. scale_factor = (true_heatmap - pred_heatmap.sigmoid()).abs().pow(2)
  65. loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_heatmap, true_heatmap, reduction="none") * scale_factor
  66. loss = loss.mean()
  67. return loss
  68. def offset_loss(self, pred_offsets, true_offsets, weights):
  69. num_pos = torch.nonzero(weights > 0).size()[0]
  70. loss = torch.nn.functional.smooth_l1_loss(pred_offsets, true_offsets, reduction="none", beta=1.0 / 9) * weights
  71. if num_pos == 0:
  72. num_pos = 1.0
  73. loss = loss.sum() / num_pos
  74. return loss
Discard
Tip!

Press p or to see the previous file or, n or to see the next file