Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
  1. import torch
  2. import torch.nn as nn
  3. from super_gradients.training.losses.dice_loss import DiceLoss, BinaryDiceLoss
  4. from super_gradients.training.utils.segmentation_utils import target_to_binary_edge
  5. from torch.nn.modules.loss import _Loss
  6. from typing import Union, Tuple
  7. from super_gradients.training.losses.mask_loss import MaskAttentionLoss
  8. class DiceCEEdgeLoss(_Loss):
  9. def __init__(
  10. self,
  11. num_classes: int,
  12. num_aux_heads: int = 2,
  13. num_detail_heads: int = 1,
  14. weights: Union[tuple, list] = (1, 1, 1, 1),
  15. dice_ce_weights: Union[tuple, list] = (1, 1),
  16. ignore_index: int = -100,
  17. edge_kernel: int = 3,
  18. ce_edge_weights: Union[tuple, list] = (0.5, 0.5),
  19. ):
  20. """
  21. Total loss is computed as follows:
  22. Loss-cls-edge = λ1 * CE + λ2 * M * CE , where [λ1, λ2] are ce_edge_weights.
  23. For each Main feature maps and auxiliary heads the loss is calculated as:
  24. Loss-main-aux = λ3 * Loss-cls-edge + λ4 * Loss-Dice, where [λ3, λ4] are dice_ce_weights.
  25. For Feature maps defined as detail maps that predicts only the edge mask, the loss is computed as follow:
  26. Loss-detail = BinaryCE + BinaryDice
  27. Finally the total loss is computed as follows for the whole feature maps:
  28. Loss = Σw[i] * Loss-main-aux[i] + Σw[j] * Loss-detail[j], where `w` is defined as the `weights` argument
  29. `i` in [0, 1 + num_aux_heads], 1 is for the main feature map.
  30. `j` in [1 + num_aux_heads, 1 + num_aux_heads + num_detail_heads].
  31. :param num_aux_heads: num of auxiliary heads.
  32. :param num_detail_heads: num of detail heads.
  33. :param weights: Loss lambda weights.
  34. :param dice_ce_weights: weights lambdas between (Dice, CE) losses.
  35. :param edge_kernel: kernel size of dilation erosion convolutions for creating the edge feature map.
  36. :param ce_edge_weights: weights lambdas between regular CE and edge attention CE.
  37. """
  38. super().__init__()
  39. # Check that arguments are valid.
  40. assert len(weights) == num_aux_heads + num_detail_heads + 1, "Lambda loss weights must be in same size as loss items."
  41. assert len(dice_ce_weights) == 2, f"dice_ce_weights must an iterable with size 2, found: {len(dice_ce_weights)}"
  42. assert len(ce_edge_weights) == 2, f"dice_ce_weights must an iterable with size 2, found: {len(ce_edge_weights)}"
  43. self.edge_kernel = edge_kernel
  44. self.num_classes = num_classes
  45. self.ignore_index = ignore_index
  46. self.weights = weights
  47. self.dice_ce_weights = dice_ce_weights
  48. self.use_detail = num_detail_heads > 0
  49. self.num_aux_heads = num_aux_heads
  50. self.num_detail_heads = num_detail_heads
  51. if self.use_detail:
  52. self.bce = nn.BCEWithLogitsLoss()
  53. self.binary_dice = BinaryDiceLoss(apply_sigmoid=True)
  54. self.ce_edge = MaskAttentionLoss(criterion=nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index), loss_weights=ce_edge_weights)
  55. self.dice_loss = DiceLoss(apply_softmax=True, ignore_index=None if ignore_index < 0 else ignore_index)
  56. @property
  57. def component_names(self):
  58. """
  59. Component names for logging during training.
  60. These correspond to 2nd item in the tuple returned in self.forward(...).
  61. See super_gradients.Trainer.train() docs for more info.
  62. """
  63. names = ["main_loss"]
  64. # Append aux losses names
  65. names += [f"aux_loss{i}" for i in range(self.num_aux_heads)]
  66. # Append detail losses names
  67. names += [f"detail_loss{i}" for i in range(self.num_detail_heads)]
  68. names += ["loss"]
  69. return names
  70. def forward(self, preds: Tuple[torch.Tensor], target: torch.Tensor):
  71. """
  72. :param preds: Model output predictions, must be in the followed format:
  73. [Main-feats, Aux-feats[0], ..., Aux-feats[num_auxs-1], Detail-feats[0], ..., Detail-feats[num_details-1]
  74. """
  75. assert (
  76. len(preds) == self.num_aux_heads + self.num_detail_heads + 1
  77. ), f"Wrong num of predictions tensors, expected {self.num_aux_heads + self.num_detail_heads + 1} found {len(preds)}"
  78. edge_target = target_to_binary_edge(
  79. target, num_classes=self.num_classes, kernel_size=self.edge_kernel, ignore_index=self.ignore_index, flatten_channels=True
  80. )
  81. losses = []
  82. total_loss = 0
  83. # Main and auxiliaries feature maps losses
  84. for i in range(0, 1 + self.num_aux_heads):
  85. ce_loss = self.ce_edge(preds[i], target, edge_target)
  86. dice_loss = self.dice_loss(preds[i], target)
  87. loss = ce_loss * self.dice_ce_weights[0] + dice_loss * self.dice_ce_weights[1]
  88. total_loss += self.weights[i] * loss
  89. losses.append(loss)
  90. # Detail feature maps losses
  91. if self.use_detail:
  92. for i in range(1 + self.num_aux_heads, len(preds)):
  93. bce_loss = self.bce(preds[i], edge_target)
  94. dice_loss = self.binary_dice(preds[i], edge_target)
  95. loss = bce_loss * self.dice_ce_weights[0] + dice_loss * self.dice_ce_weights[1]
  96. total_loss += self.weights[i] * loss
  97. losses.append(loss)
  98. losses.append(total_loss)
  99. return total_loss, torch.stack(losses, dim=0).detach()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file