Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#657 Segmentation Readme

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-594-segmentation_readme
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  1. import torch
  2. import unittest
  3. import torch.nn.functional as F
  4. from super_gradients.training.losses.ohem_ce_loss import OhemCELoss
  5. class OhemLossTest(unittest.TestCase):
  6. def setUp(self) -> None:
  7. self.img_size = 64
  8. self.eps = 0.01
  9. def test_all_hard_no_mining(self):
  10. # equal probability distribution, p = 1 / num_classes
  11. # except loss to be: -log(p)
  12. num_classes = 19
  13. targets = torch.randint(0, num_classes, (1, self.img_size, self.img_size))
  14. predictions = torch.ones((1, num_classes, self.img_size, self.img_size))
  15. probability = 1 / num_classes
  16. # All samples are hard, No Hard-mining
  17. criterion = OhemCELoss(threshold=probability + self.eps, mining_percent=0.1)
  18. expected_loss = -torch.log(torch.tensor(probability))
  19. loss = criterion(predictions, targets)
  20. self.assertAlmostEqual(expected_loss, loss, delta=1e-5)
  21. def test_hard_mining(self):
  22. num_classes = 2
  23. predictions = torch.ones((1, num_classes, self.img_size, self.img_size))
  24. targets = torch.randint(0, num_classes, (1, self.img_size, self.img_size))
  25. # create hard samples
  26. hard_class = 0
  27. mask = targets == hard_class
  28. predictions[:, hard_class, mask.squeeze()] = 0.0
  29. hard_percent = mask.sum() / targets.numel()
  30. predicted_prob = F.softmax(torch.tensor([0.0, 1.0]), dim=0)[0].item()
  31. criterion = OhemCELoss(threshold=predicted_prob + self.eps, mining_percent=hard_percent)
  32. expected_loss = -torch.log(torch.tensor(predicted_prob))
  33. loss = criterion(predictions, targets)
  34. self.assertAlmostEqual(expected_loss, loss, delta=1e-5)
  35. def test_ignore_label(self):
  36. num_classes = 2
  37. predictions = torch.ones((1, num_classes, self.img_size, self.img_size))
  38. targets = torch.randint(0, num_classes, (1, self.img_size, self.img_size))
  39. # create hard samples, to be ignored later
  40. hard_class = 0
  41. mask = targets == hard_class
  42. predictions[:, hard_class, mask.squeeze()] = 0.0
  43. # except loss to be an equal distribution, w.r.t ignoring the hard label
  44. predicted_prob = F.softmax(torch.tensor([1.0, 1.0]), dim=0)[0].item()
  45. criterion = OhemCELoss(threshold=predicted_prob + self.eps, mining_percent=1.0, ignore_lb=hard_class)
  46. expected_loss = -torch.log(torch.tensor(predicted_prob))
  47. loss = criterion(predictions, targets)
  48. self.assertAlmostEqual(expected_loss, loss, delta=1e-5)
  49. def test_all_are_ignore_label(self):
  50. num_classes = 2
  51. predictions = torch.ones((1, num_classes, self.img_size, self.img_size))
  52. targets = torch.zeros(1, self.img_size, self.img_size).long() # all targets are 0 class
  53. ignore_class = 0
  54. criterion = OhemCELoss(threshold=0.5, mining_percent=1.0, ignore_lb=ignore_class)
  55. expected_loss = 0.0 # except empty zero tensor, because all are ignore labels
  56. loss = criterion(predictions, targets)
  57. self.assertAlmostEqual(expected_loss, loss, delta=1e-5)
  58. if __name__ == "__main__":
  59. unittest.main()
Discard
Tip!

Press p to see the previous file, or n to see the next file