
#609 Ci fix

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bugfix/infra-000_ci
from abc import ABC, abstractmethod
from typing import Union, Optional, Tuple

import torch
from torch.nn.modules.loss import _Loss

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.losses.loss_utils import apply_reduce, LossReduction
from super_gradients.training.utils.segmentation_utils import to_one_hot

logger = get_logger(__name__)


class AbstarctSegmentationStructureLoss(_Loss, ABC):
    """
    Abstract computation of a structure loss between two tensors. Supports both multi-class and binary tasks.
    """

    def __init__(
        self,
        apply_softmax: bool = True,
        ignore_index: Optional[int] = None,
        smooth: float = 1.0,
        eps: float = 1e-5,
        reduce_over_batches: bool = False,
        generalized_metric: bool = False,
        weight: Optional[torch.Tensor] = None,
        reduction: Union[LossReduction, str] = "mean",
    ):
        """
        :param apply_softmax: Whether to apply softmax to the predictions.
        :param smooth: Laplace smoothing, also known as additive smoothing. The larger the smooth value, the closer
            the metric coefficient is to 1, which can be used as a regularization effect.
            As mentioned in: https://github.com/pytorch/pytorch/issues/1249#issuecomment-337999895
        :param eps: epsilon value to avoid inf.
        :param reduce_over_batches: Whether to average the metric over the batch axis if set to True;
            default is `False`, which averages over the class axis.
        :param generalized_metric: Whether to apply normalization by the volume of each class.
        :param weight: a manual rescaling weight given to each class. If given, it has to be a Tensor of size `C`.
        :param reduction: Specifies the reduction to apply to the output: `none` | `mean` | `sum`.
            `none`: no reduction will be applied.
            `mean`: the sum of the output will be divided by the number of elements in the output.
            `sum`: the output will be summed.
            Default: `mean`
        """
        super().__init__(reduction=reduction)
        self.ignore_index = ignore_index
        self.apply_softmax = apply_softmax
        self.eps = eps
        self.smooth = smooth
        self.reduce_over_batches = reduce_over_batches
        self.generalized_metric = generalized_metric
        self.weight = weight
        if self.generalized_metric:
            assert self.weight is None, "Cannot use structured loss with class weights and generalized normalization"
            if self.eps > 1e-12:
                logger.warning("When using GeneralizedLoss, it is recommended to use eps below 1e-12, so that small normalized terms are not affected.")
            if self.smooth != 0:
                logger.warning("When using GeneralizedLoss, it is recommended to set the smooth value to 0.")

    @abstractmethod
    def _calc_numerator_denominator(self, labels_one_hot, predict) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        All subclasses must implement this function.
        Return: two tensors of shape [BS, num_classes, img_width, img_height].
        """
        raise NotImplementedError()

    @abstractmethod
    def _calc_loss(self, numerator, denominator) -> torch.Tensor:
        """
        All subclasses must implement this function.
        Return a tensor of shape [BS] if self.reduce_over_batches else [num_classes].
        """
        raise NotImplementedError()

    def forward(self, predict, target):
        if self.apply_softmax:
            predict = torch.softmax(predict, dim=1)
        # Convert target to one-hot format.
        if target.size() == predict.size():
            labels_one_hot = target
        elif target.dim() == 3:  # if the target tensor is in class-index format.
            if predict.size(1) == 1 and self.ignore_index is None:  # if one-class prediction task
                labels_one_hot = target.unsqueeze(1)
            else:
                labels_one_hot = to_one_hot(target, num_classes=predict.shape[1], ignore_index=self.ignore_index)
        else:
            raise AssertionError(
                f"Mismatch of target shape: {target.size()} and prediction shape: {predict.size()},"
                f" target must be a [NxWxH] tensor for to_one_hot conversion"
                f" or have the same number of channels as the prediction tensor"
            )
        reduce_spatial_dims = list(range(2, len(predict.shape)))
        reduce_dims = [1] + reduce_spatial_dims if self.reduce_over_batches else [0] + reduce_spatial_dims
        # Calculate the numerator and denominator of the chosen metric.
        numerator, denominator = self._calc_numerator_denominator(labels_one_hot, predict)
        # Exclude ignore labels from the numerator and denominator: false positives predicted
        # on ignored samples are not included in the total calculation.
        if self.ignore_index is not None:
            valid_mask = target.ne(self.ignore_index).unsqueeze(1).expand_as(denominator)
            numerator *= valid_mask
            denominator *= valid_mask
        numerator = torch.sum(numerator, dim=reduce_dims)
        denominator = torch.sum(denominator, dim=reduce_dims)
        if self.generalized_metric:
            weights = 1.0 / (torch.sum(labels_one_hot, dim=reduce_dims) ** 2)
            # If some classes are not present in the batch, their weights will be inf.
            infs = torch.isinf(weights)
            weights[infs] = 0.0
            numerator *= weights
            denominator *= weights
        # Calculate the loss of the chosen metric.
        losses = self._calc_loss(numerator, denominator)
        if self.weight is not None:
            losses *= self.weight
        return apply_reduce(losses, reduction=self.reduction)
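
To see how the two abstract hooks fit together, below is a minimal sketch of a concrete subclass. The name SketchDiceLoss and its Dice-style formulation are illustrative assumptions for this note, not code from this PR (the super_gradients library ships its own Dice/Jaccard subclasses).

# Hypothetical example, for illustration only (not part of this PR):
# a minimal Dice-style subclass showing what the abstract hooks should return.
class SketchDiceLoss(AbstarctSegmentationStructureLoss):
    def _calc_numerator_denominator(self, labels_one_hot, predict) -> Tuple[torch.Tensor, torch.Tensor]:
        # Per-pixel intersection and sum terms, both kept at full
        # [BS, num_classes, H, W] shape; the base class sums them over reduce_dims.
        numerator = labels_one_hot * predict
        denominator = labels_one_hot + predict
        return numerator, denominator

    def _calc_loss(self, numerator, denominator) -> torch.Tensor:
        # Dice coefficient per class (or per batch element when reduce_over_batches=True),
        # using the smooth/eps attributes set by the base-class constructor.
        dice = (2.0 * numerator + self.smooth) / (denominator + self.smooth + self.eps)
        return 1.0 - dice

# Usage sketch: logits of shape [BS, C, H, W] against integer labels of shape [BS, H, W].
loss_fn = SketchDiceLoss(apply_softmax=True)
predict = torch.randn(2, 3, 64, 64)
target = torch.randint(0, 3, size=(2, 64, 64))
loss = loss_fn(predict, target)  # scalar, given the default "mean" reduction

The split means every structure loss of this family shares the one-hot conversion, ignore-index masking, and generalized class weighting in forward, while a subclass only defines the metric's two terms.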