Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#378 Feature/sg 281 add kd notebook

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-281-add_kd_notebook
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. import torch
  2. from torch import nn
  3. from torch.nn.modules.loss import _Loss
  4. from super_gradients.training.exceptions.loss_exceptions import IllegalRangeForLossAttributeException, RequiredLossComponentReductionException
  5. class OhemLoss(_Loss):
  6. """
  7. OhemLoss - Online Hard Example Mining Cross Entropy Loss
  8. """
  9. def __init__(self,
  10. threshold: float,
  11. mining_percent: float = 0.1,
  12. ignore_lb: int = -100,
  13. num_pixels_exclude_ignored: bool = True,
  14. criteria: _Loss = None):
  15. """
  16. :param threshold: Sample below probability threshold, is considered hard.
  17. :param num_pixels_exclude_ignored: How to calculate total pixels from which extract mining percent of the
  18. samples.
  19. :param ignore_lb: label index to be ignored in loss calculation.
  20. :param criteria: loss to mine the examples from.
  21. i.e for num_pixels=100, ignore_pixels=30, mining_percent=0.1:
  22. num_pixels_exclude_ignored=False => num_mining = 100 * 0.1 = 10
  23. num_pixels_exclude_ignored=True => num_mining = (100 - 30) * 0.1 = 7
  24. """
  25. super().__init__()
  26. if mining_percent < 0 or mining_percent > 1:
  27. raise IllegalRangeForLossAttributeException((0, 1), "mining percent")
  28. self.thresh = -torch.log(torch.tensor(threshold, dtype=torch.float))
  29. self.mining_percent = mining_percent
  30. self.ignore_lb = ignore_lb
  31. self.num_pixels_exclude_ignored = num_pixels_exclude_ignored
  32. if criteria.reduction != 'none':
  33. raise RequiredLossComponentReductionException("criteria", criteria.reduction, 'none')
  34. self.criteria = criteria
  35. def forward(self, logits, labels):
  36. loss = self.criteria(logits, labels).view(-1)
  37. if self.num_pixels_exclude_ignored:
  38. # remove ignore label elements
  39. loss = loss[labels.view(-1) != self.ignore_lb]
  40. # num pixels in a batch -> num_pixels = batch_size * width * height - ignore_pixels
  41. num_pixels = loss.numel()
  42. else:
  43. num_pixels = labels.numel()
  44. # if all pixels are ignore labels, return empty loss tensor
  45. if num_pixels == 0:
  46. return torch.tensor([0.]).requires_grad_(True)
  47. num_mining = int(self.mining_percent * num_pixels)
  48. # in case mining_percent=1, prevent out of bound exception
  49. num_mining = min(num_mining, num_pixels - 1)
  50. self.thresh = self.thresh.to(logits.device)
  51. loss, _ = torch.sort(loss, descending=True)
  52. if loss[num_mining] > self.thresh:
  53. loss = loss[loss > self.thresh]
  54. else:
  55. loss = loss[:num_mining]
  56. return torch.mean(loss)
  57. class OhemCELoss(OhemLoss):
  58. """
  59. OhemLoss - Online Hard Example Mining Cross Entropy Loss
  60. """
  61. def __init__(self,
  62. threshold: float,
  63. mining_percent: float = 0.1,
  64. ignore_lb: int = -100,
  65. num_pixels_exclude_ignored: bool = True):
  66. ignore_lb = -100 if ignore_lb is None or ignore_lb < 0 else ignore_lb
  67. criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
  68. super(OhemCELoss, self).__init__(threshold=threshold,
  69. mining_percent=mining_percent,
  70. ignore_lb=ignore_lb,
  71. num_pixels_exclude_ignored=num_pixels_exclude_ignored,
  72. criteria=criteria)
  73. class OhemBCELoss(OhemLoss):
  74. """
  75. OhemBCELoss - Online Hard Example Mining Binary Cross Entropy Loss
  76. """
  77. def __init__(self,
  78. threshold: float,
  79. mining_percent: float = 0.1,
  80. ignore_lb: int = -100,
  81. num_pixels_exclude_ignored: bool = True, ):
  82. super(OhemBCELoss, self).__init__(threshold=threshold,
  83. mining_percent=mining_percent,
  84. ignore_lb=ignore_lb,
  85. num_pixels_exclude_ignored=num_pixels_exclude_ignored,
  86. criteria=nn.BCEWithLogitsLoss(reduction='none'))
  87. def forward(self, logits, labels):
  88. # REMOVE SINGLE CLASS CHANNEL WHEN DEALING WITH BINARY DATA
  89. if logits.shape[1] == 1:
  90. logits = logits.squeeze(1)
  91. return super(OhemBCELoss, self).forward(logits, labels.float())
Discard
Tip!

Press p or to see the previous file or, n or to see the next file