Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

segmentation_metrics.py 4.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
  1. import numpy as np
  2. import torch
  3. import torchmetrics
  4. from torchmetrics import Metric
  5. def batch_pix_accuracy(predict, target):
  6. """Batch Pixel Accuracy
  7. Args:
  8. predict: input 4D tensor
  9. target: label 3D tensor
  10. """
  11. _, predict = torch.max(predict, 1)
  12. predict = predict.cpu().numpy() + 1
  13. target = target.cpu().numpy() + 1
  14. pixel_labeled = np.sum(target > 0)
  15. pixel_correct = np.sum((predict == target) * (target > 0))
  16. assert pixel_correct <= pixel_labeled, \
  17. "Correct area should be smaller than Labeled"
  18. return pixel_correct, pixel_labeled
  19. def batch_intersection_union(predict, target, nclass):
  20. """Batch Intersection of Union
  21. Args:
  22. predict: input 4D tensor
  23. target: label 3D tensor
  24. nclass: number of categories (int)
  25. """
  26. _, predict = torch.max(predict, 1)
  27. mini = 1
  28. maxi = nclass
  29. nbins = nclass
  30. predict = predict.cpu().numpy() + 1
  31. target = target.cpu().numpy() + 1
  32. predict = predict * (target > 0).astype(predict.dtype)
  33. intersection = predict * (predict == target)
  34. # areas of intersection and union
  35. area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
  36. area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
  37. area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
  38. area_union = area_pred + area_lab - area_inter
  39. assert (area_inter <= area_union).all(), \
  40. "Intersection area should be smaller than Union area"
  41. return area_inter, area_union
  42. # ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
  43. def pixel_accuracy(im_pred, im_lab):
  44. im_pred = np.asarray(im_pred)
  45. im_lab = np.asarray(im_lab)
  46. # Remove classes from unlabeled pixels in gt image.
  47. # We should not penalize detections in unlabeled portions of the image.
  48. pixel_labeled = np.sum(im_lab > 0)
  49. pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
  50. # pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
  51. return pixel_correct, pixel_labeled
  52. def intersection_and_union(im_pred, im_lab, num_class):
  53. im_pred = np.asarray(im_pred)
  54. im_lab = np.asarray(im_lab)
  55. # Remove classes from unlabeled pixels in gt image.
  56. im_pred = im_pred * (im_lab > 0)
  57. # Compute area intersection:
  58. intersection = im_pred * (im_pred == im_lab)
  59. area_inter, _ = np.histogram(intersection, bins=num_class - 1,
  60. range=(1, num_class - 1))
  61. # Compute area union:
  62. area_pred, _ = np.histogram(im_pred, bins=num_class - 1,
  63. range=(1, num_class - 1))
  64. area_lab, _ = np.histogram(im_lab, bins=num_class - 1,
  65. range=(1, num_class - 1))
  66. area_union = area_pred + area_lab - area_inter
  67. return area_inter, area_union
  68. class PixelAccuracy(Metric):
  69. def __init__(self, ignore_label=-100, dist_sync_on_step=False):
  70. super().__init__(dist_sync_on_step=dist_sync_on_step)
  71. self.ignore_label = ignore_label
  72. self.add_state("total_correct", default=torch.tensor(0.), dist_reduce_fx="sum")
  73. self.add_state("total_label", default=torch.tensor(0.), dist_reduce_fx="sum")
  74. def update(self, preds: torch.Tensor, target: torch.Tensor):
  75. if isinstance(preds, tuple):
  76. preds = preds[0]
  77. _, predict = torch.max(preds, 1)
  78. labeled_mask = target.ne(self.ignore_label)
  79. pixel_labeled = torch.sum(labeled_mask)
  80. pixel_correct = torch.sum((predict == target) * labeled_mask)
  81. self.total_correct += pixel_correct
  82. self.total_label += pixel_labeled
  83. def compute(self):
  84. _total_correct = self.total_correct.cpu().detach().numpy().astype('int64')
  85. _total_label = self.total_label.cpu().detach().numpy().astype('int64')
  86. pix_acc = np.float64(1.0) * _total_correct / (np.spacing(1, dtype=np.float64) + _total_label)
  87. return pix_acc
  88. class IoU(torchmetrics.IoU):
  89. def __init__(self, num_classes, dist_sync_on_step=True, ignore_index=None):
  90. super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index)
  91. def update(self, preds, target: torch.Tensor):
  92. # WHEN DEALING WITH MULTIPLE OUTPUTS- OUTPUTS[0] IS THE MAIN SEGMENTATION MAP
  93. if isinstance(preds, tuple):
  94. preds = preds[0]
  95. _, preds = torch.max(preds, 1)
  96. super().update(preds=preds, target=target)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...