dice_ce_edge_loss.py

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from typing import Union, Tuple

from super_gradients.training.losses.dice_loss import DiceLoss, BinaryDiceLoss
from super_gradients.training.utils.segmentation_utils import target_to_binary_edge
from super_gradients.common.object_names import Losses
from super_gradients.common.registry.registry import register_loss
from super_gradients.training.losses.mask_loss import MaskAttentionLoss


@register_loss(name=Losses.DICE_CE_EDGE_LOSS, deprecated_name="dice_ce_edge_loss")
class DiceCEEdgeLoss(_Loss):
    def __init__(
        self,
        num_classes: int,
        num_aux_heads: int = 2,
        num_detail_heads: int = 1,
        weights: Union[tuple, list] = (1, 1, 1, 1),
        dice_ce_weights: Union[tuple, list] = (1, 1),
        ignore_index: int = -100,
        edge_kernel: int = 3,
        ce_edge_weights: Union[tuple, list] = (0.5, 0.5),
    ):
        """
        Total loss is computed as follows:

            Loss-cls-edge = λ1 * CE + λ2 * M * CE, where M is the binary edge mask and [λ1, λ2] are ce_edge_weights.

        For the main feature map and each auxiliary head, the loss is calculated as:

            Loss-main-aux = λ3 * Loss-cls-edge + λ4 * Loss-Dice, where [λ3, λ4] are dice_ce_weights.

        For feature maps defined as detail maps, which predict only the edge mask, the loss is computed as:

            Loss-detail = BinaryCE + BinaryDice

        Finally, the total loss over all feature maps is:

            Loss = Σ w[i] * Loss-main-aux[i] + Σ w[j] * Loss-detail[j], where `w` is the `weights` argument,
                `i` in [0, 1 + num_aux_heads), with index 0 for the main feature map,
                `j` in [1 + num_aux_heads, 1 + num_aux_heads + num_detail_heads).

        :param num_classes: number of classes in the model's predictions.
        :param num_aux_heads: number of auxiliary heads.
        :param num_detail_heads: number of detail heads.
        :param weights: loss lambda weights, one per head (main, aux..., detail...).
        :param dice_ce_weights: weight lambdas between the (CE, Dice) losses.
        :param ignore_index: target label to ignore when computing the losses.
        :param edge_kernel: kernel size of the dilation/erosion convolutions used to create the edge feature map.
        :param ce_edge_weights: weight lambdas between the regular CE and the edge-attention CE.
        """
        super().__init__()
        # Check that arguments are valid.
        assert len(weights) == num_aux_heads + num_detail_heads + 1, "Lambda loss weights must be the same size as the loss items."
        assert len(dice_ce_weights) == 2, f"dice_ce_weights must be an iterable with size 2, found: {len(dice_ce_weights)}"
        assert len(ce_edge_weights) == 2, f"ce_edge_weights must be an iterable with size 2, found: {len(ce_edge_weights)}"

        self.edge_kernel = edge_kernel
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.weights = weights
        self.dice_ce_weights = dice_ce_weights

        self.use_detail = num_detail_heads > 0

        self.num_aux_heads = num_aux_heads
        self.num_detail_heads = num_detail_heads

        if self.use_detail:
            self.bce = nn.BCEWithLogitsLoss()
            self.binary_dice = BinaryDiceLoss(apply_sigmoid=True)

        self.ce_edge = MaskAttentionLoss(criterion=nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index), loss_weights=ce_edge_weights)
        self.dice_loss = DiceLoss(apply_softmax=True, ignore_index=None if ignore_index < 0 else ignore_index)

    @property
    def component_names(self):
        """
        Component names for logging during training.
        These correspond to the 2nd item in the tuple returned in self.forward(...).
        See super_gradients.Trainer.train() docs for more info.
        """
        names = ["main_loss"]
        # Append aux loss names.
        names += [f"aux_loss{i}" for i in range(self.num_aux_heads)]
        # Append detail loss names.
        names += [f"detail_loss{i}" for i in range(self.num_detail_heads)]
        names += ["loss"]
        return names

    def forward(self, preds: Tuple[torch.Tensor], target: torch.Tensor):
        """
        :param preds: Model output predictions, must be in the following format:
            [Main-feats, Aux-feats[0], ..., Aux-feats[num_auxs-1], Detail-feats[0], ..., Detail-feats[num_details-1]]
        """
        assert (
            len(preds) == self.num_aux_heads + self.num_detail_heads + 1
        ), f"Wrong number of prediction tensors, expected {self.num_aux_heads + self.num_detail_heads + 1}, found {len(preds)}"

        # Build a binary edge mask from the segmentation target; it weights edge pixels in the
        # edge-attention CE and serves as the target for the detail heads.
        edge_target = target_to_binary_edge(
            target, num_classes=self.num_classes, kernel_size=self.edge_kernel, ignore_index=self.ignore_index, flatten_channels=True
        )
        losses = []
        total_loss = 0

        # Main and auxiliary feature map losses.
        for i in range(0, 1 + self.num_aux_heads):
            ce_loss = self.ce_edge(preds[i], target, edge_target)
            dice_loss = self.dice_loss(preds[i], target)
            loss = ce_loss * self.dice_ce_weights[0] + dice_loss * self.dice_ce_weights[1]
            total_loss += self.weights[i] * loss
            losses.append(loss)

        # Detail feature map losses.
        if self.use_detail:
            for i in range(1 + self.num_aux_heads, len(preds)):
                bce_loss = self.bce(preds[i], edge_target)
                dice_loss = self.binary_dice(preds[i], edge_target)
                loss = bce_loss * self.dice_ce_weights[0] + dice_loss * self.dice_ce_weights[1]
                total_loss += self.weights[i] * loss
                losses.append(loss)

        losses.append(total_loss)
        return total_loss, torch.stack(losses, dim=0).detach()
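
For reference, below is a minimal usage sketch; it is not part of the original file. It assumes the constructor defaults (one main head, two auxiliary heads, one detail head), illustrative tensor shapes, and an import path inferred from the sibling imports above.

import torch
from super_gradients.training.losses.dice_ce_edge_loss import DiceCEEdgeLoss  # assumed module path

# Hypothetical shapes: batch of 4, 19 classes, 64x64 feature maps.
num_classes, batch, h, w = 19, 4, 64, 64
criterion = DiceCEEdgeLoss(num_classes=num_classes)  # defaults: 2 aux heads, 1 detail head

# One main and two aux multi-class logit maps, plus one single-channel detail (edge) logit map.
main_and_aux = [torch.randn(batch, num_classes, h, w) for _ in range(3)]
detail = [torch.randn(batch, 1, h, w)]
target = torch.randint(0, num_classes, (batch, h, w))

total_loss, components = criterion(tuple(main_and_aux + detail), target)
print(criterion.component_names)  # ['main_loss', 'aux_loss0', 'aux_loss1', 'detail_loss0', 'loss']
print(components)  # detached per-component values, aligned with component_names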