#869 Add DagsHub Logger to Super Gradients

Merged
Ghost merged 1 commit into Deci-AI:master from timho102003:dagshub_logger
from typing import Tuple

import torch
from torch import Tensor, nn

from super_gradients.common.object_names import Losses
from super_gradients.common.registry.registry import register_loss


@register_loss(Losses.DEKR_LOSS)
class DEKRLoss(nn.Module):
    """
    Implementation of the loss function from the "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression"
    paper (https://arxiv.org/abs/2104.02300).

    This loss should be used in conjunction with DEKRTargetsGenerator.
    """

    def __init__(self, heatmap_loss_factor: float = 1.0, offset_loss_factor: float = 0.1):
        """
        Instantiate the DEKR loss function. It is a two-component loss function, consisting of a heatmap (MSE) loss and an offset (Smooth L1) loss.
        The total loss is the sum of the two individual losses, weighted by the corresponding factors.

        :param heatmap_loss_factor: Weighting factor for the heatmap loss
        :param offset_loss_factor:  Weighting factor for the offset loss
        """
        super().__init__()
        self.heatmap_loss_factor = float(heatmap_loss_factor)
        self.offset_loss_factor = float(offset_loss_factor)

    @property
    def component_names(self):
        """
        Names of the individual loss components for logging during training.
        """
        return ["heatmap", "offset", "total"]

    def forward(self, predictions: Tuple[Tensor, Tensor], targets: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        :param predictions: Tuple of (heatmap, offset) predictions.
            heatmap is of shape (B, NumJoints + 1, H, W)
            offset is of shape (B, NumJoints * 2, H, W)

        :param targets: Tuple of (heatmap, mask, offset, offset_weight).
            heatmap is of shape (B, NumJoints + 1, H, W)
            mask is of shape (B, NumJoints + 1, H, W)
            offset is of shape (B, NumJoints * 2, H, W)
            offset_weight is of shape (B, NumJoints * 2, H, W)

        :return: Tuple of (loss, loss_components)
            loss is a scalar tensor with the total loss
            loss_components is a tensor of shape (3,) containing the individual loss components for logging (detached from the graph)
        """
        pred_heatmap, pred_offset = predictions
        gt_heatmap, mask, gt_offset, offset_weight = targets

        heatmap_loss = self.heatmap_loss(pred_heatmap, gt_heatmap, mask) * self.heatmap_loss_factor
        offset_loss = self.offset_loss(pred_offset, gt_offset, offset_weight) * self.offset_loss_factor

        loss = heatmap_loss + offset_loss

        components = torch.cat(
            (
                heatmap_loss.unsqueeze(0),
                offset_loss.unsqueeze(0),
                loss.unsqueeze(0),
            )
        ).detach()

        return loss, components

    def heatmap_loss(self, pred_heatmap: Tensor, true_heatmap: Tensor, mask: Tensor) -> Tensor:
        # Per-element MSE, masked to zero out unsupervised regions, then averaged over all elements.
        loss = torch.nn.functional.mse_loss(pred_heatmap, true_heatmap, reduction="none") * mask
        loss = loss.mean()
        return loss

    def offset_loss(self, pred_offsets: Tensor, true_offsets: Tensor, weights: Tensor) -> Tensor:
        # Per-element Smooth L1 loss on keypoint offsets, weighted per location and
        # normalized by the number of positive (weight > 0) locations.
        num_pos = torch.nonzero(weights > 0).size()[0]
        loss = torch.nn.functional.smooth_l1_loss(pred_offsets, true_offsets, reduction="none", beta=1.0 / 9) * weights
        if num_pos == 0:
            num_pos = 1.0
        loss = loss.sum() / num_pos
        return loss
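For context, here is a minimal sketch of how this loss could be exercised with dummy tensors. The import path, batch size, joint count, and spatial dimensions below are illustrative assumptions, not values taken from this PR:

    import torch

    # Assumed module path; adjust to wherever DEKRLoss lives in your super-gradients install.
    from super_gradients.training.losses.dekr_loss import DEKRLoss

    B, num_joints, H, W = 2, 17, 128, 128  # illustrative shapes (e.g. COCO uses 17 joints)

    # Predictions: (heatmap, offset)
    pred_heatmap = torch.randn(B, num_joints + 1, H, W)
    pred_offset = torch.randn(B, num_joints * 2, H, W)

    # Targets: (heatmap, mask, offset, offset_weight), normally produced by DEKRTargetsGenerator
    gt_heatmap = torch.rand(B, num_joints + 1, H, W)
    mask = torch.ones(B, num_joints + 1, H, W)
    gt_offset = torch.randn(B, num_joints * 2, H, W)
    offset_weight = torch.rand(B, num_joints * 2, H, W)

    criterion = DEKRLoss(heatmap_loss_factor=1.0, offset_loss_factor=0.1)
    loss, components = criterion((pred_heatmap, pred_offset), (gt_heatmap, mask, gt_offset, offset_weight))

    # components lines up with criterion.component_names: ["heatmap", "offset", "total"]
    print(dict(zip(criterion.component_names, components.tolist())))

The detached `components` tensor is what a training-time logger (such as the DagsHub logger this PR adds) would consume, while the attached scalar `loss` drives backpropagation.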