1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
|
- import numpy as np
- import torch
- import torchmetrics
- from torchmetrics import Metric
def batch_pix_accuracy(predict, target):
    """Count correctly classified pixels and labeled pixels for one batch.

    Both label maps are shifted up by one before counting. After the shift,
    any target value that was -1 or lower is non-positive and is treated as
    unlabeled, so it is excluded from both counts.

    Args:
        predict: input 4D tensor of class scores; the predicted class of each
            pixel is the argmax along dimension 1.
        target: label 3D tensor of integer class ids.

    Returns:
        Tuple ``(pixel_correct, pixel_labeled)`` of NumPy integer counts.
    """
    class_index = torch.max(predict, 1)[1]
    shifted_pred = class_index.cpu().numpy() + 1
    shifted_gt = target.cpu().numpy() + 1
    # Only pixels whose shifted ground truth is positive are scored.
    valid = shifted_gt > 0
    pixel_labeled = np.sum(valid)
    pixel_correct = np.sum((shifted_pred == shifted_gt) * valid)
    # Invariant check: correct pixels are a subset of labeled pixels.
    assert pixel_correct <= pixel_labeled, \
        "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
def batch_intersection_union(predict, target, nclass):
    """Per-class intersection and union pixel areas for one batch.

    Predictions and labels are shifted up by one, so class ``k`` becomes id
    ``k + 1``. Target values of -1 or lower become non-positive and are
    treated as unlabeled.

    Args:
        predict: input 4D tensor of class scores; the predicted class of each
            pixel is the argmax along dimension 1.
        target: label 3D tensor of integer class ids.
        nclass: number of categories (int).

    Returns:
        Tuple ``(area_inter, area_union)`` of NumPy integer arrays of length
        ``nclass``. Entry ``i`` holds the counts for class ``i``.
    """
    # nclass equal-width bins spanning [1, nclass]: every integer id k in
    # 1..nclass lands in its own bin k - 1 (the last edge is inclusive).
    hist_kwargs = dict(bins=nclass, range=(1, nclass))

    pred_labels = torch.max(predict, 1)[1].cpu().numpy() + 1
    gt_labels = target.cpu().numpy() + 1

    # Zero the prediction on unlabeled pixels so it falls outside the range.
    pred_labels = pred_labels * (gt_labels > 0).astype(pred_labels.dtype)
    # Agreeing pixels keep their id; disagreeing pixels become 0 and are dropped.
    matched = pred_labels * (pred_labels == gt_labels)

    area_inter = np.histogram(matched, **hist_kwargs)[0]
    area_pred = np.histogram(pred_labels, **hist_kwargs)[0]
    area_lab = np.histogram(gt_labels, **hist_kwargs)[0]
    area_union = area_pred + area_lab - area_inter

    # Invariant check: intersection never exceeds union.
    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union
- # ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
- def pixel_accuracy(im_pred, im_lab):
- im_pred = np.asarray(im_pred)
- im_lab = np.asarray(im_lab)
- # Remove classes from unlabeled pixels in gt image.
- # We should not penalize detections in unlabeled portions of the image.
- pixel_labeled = np.sum(im_lab > 0)
- pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
- # pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
- return pixel_correct, pixel_labeled
def intersection_and_union(im_pred, im_lab, num_class):
    """Per-class intersection and union areas for one prediction/label pair.

    Label 0 (and any non-positive ground truth) is treated as unlabeled and is
    not scored. Areas are reported for class ids ``1 .. num_class - 1``.

    Args:
        im_pred: predicted label map (anything ``np.asarray`` accepts).
        im_lab: ground-truth label map with the same shape as ``im_pred``.
        num_class: number of label values including the unlabeled id 0.

    Returns:
        Tuple ``(area_inter, area_union)`` of integer arrays of length
        ``num_class - 1``. Entry ``i`` holds the counts for class id ``i + 1``.
    """
    lab_arr = np.asarray(im_lab)
    # Drop predictions on unlabeled pixels so they do not penalize the score.
    pred_arr = np.asarray(im_pred) * (lab_arr > 0)
    # Agreeing pixels keep their id; disagreeing pixels become 0 (outside range).
    hits = pred_arr * (pred_arr == lab_arr)

    # num_class - 1 equal-width bins over [1, num_class - 1]: each integer id
    # lands in its own bin.
    n_bins = num_class - 1
    bin_range = (1, n_bins)
    area_inter = np.histogram(hits, bins=n_bins, range=bin_range)[0]
    area_pred = np.histogram(pred_arr, bins=n_bins, range=bin_range)[0]
    area_lab = np.histogram(lab_arr, bins=n_bins, range=bin_range)[0]
    return area_inter, area_pred + area_lab - area_inter
class PixelAccuracy(Metric):
    """Pixel accuracy accumulated over every ``update`` since the last reset.

    Pixels whose target equals ``ignore_label`` are excluded from both the
    correct count and the labeled count.

    Args:
        ignore_label: target value marking pixels that are not scored.
        dist_sync_on_step: forwarded to the ``Metric`` base class.
    """

    def __init__(self, ignore_label=-100, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.ignore_label = ignore_label
        # Pixel counts are integers. Keep them in int64: the former
        # ``torch.tensor(0.)`` default is float32, which represents integers
        # exactly only up to 2**24, and dataset-wide totals exceed that quickly.
        self.add_state("total_correct", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")
        self.add_state("total_label", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate correct and labeled pixel counts for one batch.

        Args:
            preds: class scores whose class axis is dimension 1, or a tuple
                whose first element is that tensor (multi-output models).
            target: integer labels shaped like ``preds`` without dimension 1.
        """
        if isinstance(preds, tuple):
            preds = preds[0]
        predict = torch.max(preds, 1)[1]
        labeled_mask = target.ne(self.ignore_label)
        self.total_correct += torch.sum((predict == target) & labeled_mask)
        self.total_label += torch.sum(labeled_mask)

    def compute(self):
        """Return ``total_correct / total_label`` as a NumPy float64.

        A machine epsilon is added to the denominator, so the result is 0.0
        when no labeled pixel has been seen.
        """
        _total_correct = self.total_correct.detach().cpu().numpy().astype('int64')
        _total_label = self.total_label.detach().cpu().numpy().astype('int64')
        pix_acc = np.float64(1.0) * _total_correct / (np.spacing(1, dtype=np.float64) + _total_label)
        return pix_acc
class IoU(torchmetrics.IoU):
    """``torchmetrics.IoU`` that accepts raw per-class scores in ``update``.

    NOTE(review): newer torchmetrics releases renamed this base class
    (``JaccardIndex``). Confirm the pinned torchmetrics version still
    provides ``torchmetrics.IoU``.

    Args:
        num_classes: forwarded to the base class.
        dist_sync_on_step: forwarded to the base class (defaults to True here).
        ignore_index: forwarded to the base class.
    """

    def __init__(self, num_classes, dist_sync_on_step=True, ignore_index=None):
        super().__init__(
            num_classes=num_classes,
            dist_sync_on_step=dist_sync_on_step,
            ignore_index=ignore_index,
        )

    def update(self, preds, target: torch.Tensor):
        """Convert scores to a label map (argmax over dim 1) and update the base metric.

        A tuple ``preds`` comes from a multi-output model; its element 0 is
        the main segmentation map.
        """
        main_out = preds[0] if isinstance(preds, tuple) else preds
        label_map = torch.max(main_out, dim=1).indices
        super().update(preds=label_map, target=target)
|