
#869 Add DagsHub Logger to Super Gradients

Merged
Ghost merged 1 commit into Deci-AI:master from timho102003:dagshub_logger
from typing import Tuple

import torch
from torch import Tensor


def compute_visible_bbox_xywh(joints: Tensor, visibility_mask: Tensor) -> Tensor:
    """
    Compute the bounding box (X, Y, W, H) of the visible joints for each instance.

    :param joints: [Num Instances, Num Joints, 2+] The last channel must have a dimension of at least 2,
        which is considered to contain the (X, Y) coordinates of the keypoint.
    :param visibility_mask: [Num Instances, Num Joints]
    :return: A tensor [Num Instances, 4] whose last dimension contains the bbox in XYWH format.
    """
    visibility_mask = visibility_mask > 0
    initial_value = 1_000_000

    # torch.min/torch.max do not accept numpy-style `where`/`initial` arguments,
    # so invisible joints are masked out with a sentinel value instead.
    x1 = joints[:, :, 0].masked_fill(~visibility_mask, initial_value).amin(dim=-1)
    y1 = joints[:, :, 1].masked_fill(~visibility_mask, initial_value).amin(dim=-1)

    # Instances with no visible joints keep the sentinel; reset them to zero
    x1[x1 == initial_value] = 0
    y1[y1 == initial_value] = 0

    x2 = joints[:, :, 0].masked_fill(~visibility_mask, 0).amax(dim=-1)
    y2 = joints[:, :, 1].masked_fill(~visibility_mask, 0).amax(dim=-1)

    w = x2 - x1
    h = y2 - y1
    return torch.stack([x1, y1, w, h], dim=-1)
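
A quick sanity check for this helper (hypothetical values, not part of the PR): an instance with one invisible joint should get its bbox from the visible joints only.

import torch

joints = torch.tensor([[[10.0, 20.0], [30.0, 60.0], [500.0, 500.0]]])  # [1 instance, 3 joints, XY]
vis = torch.tensor([[1, 1, 0]])  # third joint marked invisible
print(compute_visible_bbox_xywh(joints, vis))  # tensor([[10., 20., 20., 40.]])
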
def compute_oks(
    pred_joints: Tensor,
    gt_joints: Tensor,
    gt_keypoint_visibility: Tensor,
    sigmas: Tensor,
    gt_areas: Tensor = None,
    gt_bboxes: Tensor = None,
) -> Tensor:
    """
    Compute the Object Keypoint Similarity (OKS) between each prediction and each ground truth instance.

    :param pred_joints: [K, NumJoints, 2] or [K, NumJoints, 3]
    :param gt_joints: [M, NumJoints, 2]
    :param gt_keypoint_visibility: [M, NumJoints]
    :param sigmas: [NumJoints]
    :param gt_areas: [M] Area of each ground truth instance. COCOEval uses the area of the instance mask
        to scale OKS, so it must be provided separately.
        If None, the area of the bounding box of each instance computed from gt_joints is used.
    :param gt_bboxes: [M, 4] Bounding box (X, Y, W, H) of each ground truth instance.
        If None, the bounding box of each instance computed from gt_joints is used.
    :return: IoU matrix [K, M]
    """
    ious = torch.zeros((len(pred_joints), len(gt_joints)), device=pred_joints.device)
    variances = (sigmas * 2) ** 2  # renamed from `vars` to avoid shadowing the builtin

    if gt_bboxes is None:
        gt_bboxes = compute_visible_bbox_xywh(gt_joints, gt_keypoint_visibility)

    if gt_areas is None:
        gt_areas = gt_bboxes[:, 2] * gt_bboxes[:, 3]

    # Compute OKS between each detection and each ground truth object
    for gt_index, (gt_keypoints, gt_visibility, gt_bbox, gt_area) in enumerate(
        zip(gt_joints, gt_keypoint_visibility, gt_bboxes, gt_areas)
    ):
        # Create bounds for ignore regions (double the gt bbox)
        xg = gt_keypoints[:, 0]
        yg = gt_keypoints[:, 1]
        k1 = torch.count_nonzero(gt_visibility > 0)

        x0 = gt_bbox[0] - gt_bbox[2]
        x1 = gt_bbox[0] + gt_bbox[2] * 2
        y0 = gt_bbox[1] - gt_bbox[3]
        y1 = gt_bbox[1] + gt_bbox[3] * 2

        for pred_index, pred_keypoints in enumerate(pred_joints):
            xd = pred_keypoints[:, 0]
            yd = pred_keypoints[:, 1]
            if k1 > 0:
                # Measure the per-keypoint distance if keypoints are visible
                dx = xd - xg
                dy = yd - yg
            else:
                # Measure the minimum distance to the ignore-region bounds (x0, y0) & (x1, y1)
                dx = (x0 - xd).clamp_min(0) + (xd - x1).clamp_min(0)
                dy = (y0 - yd).clamp_min(0) + (yd - y1).clamp_min(0)

            e = (dx**2 + dy**2) / variances / (gt_area + torch.finfo(torch.float64).eps) / 2
            if k1 > 0:
                e = e[gt_visibility > 0]
            ious[pred_index, gt_index] = torch.sum(torch.exp(-e)) / e.shape[0]

    return ious
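
A minimal correctness check for the OKS loop (hypothetical values): a prediction that coincides exactly with the ground truth keypoints should score OKS == 1.0, since every per-keypoint term is exp(0).

import torch

gt = torch.tensor([[[10.0, 10.0], [50.0, 50.0]]])  # [M=1, NumJoints=2, XY]
vis = torch.ones(1, 2)  # both joints visible
sig = torch.tensor([0.05, 0.05])
oks = compute_oks(pred_joints=gt.clone(), gt_joints=gt, gt_keypoint_visibility=vis, sigmas=sig)
assert torch.allclose(oks, torch.ones(1, 1))
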
def compute_img_keypoint_matching(
    preds: Tensor,
    pred_scores: Tensor,
    targets: Tensor,
    targets_visibilities: Tensor,
    targets_areas: Tensor,
    targets_bboxes: Tensor,
    targets_ignored: Tensor,
    crowd_targets: Tensor,
    crowd_visibilities: Tensor,
    crowd_targets_areas: Tensor,
    crowd_targets_bboxes: Tensor,
    iou_thresholds: Tensor,
    sigmas: Tensor,
    top_k: int,
) -> Tuple[Tensor, Tensor, Tensor, int]:
    """
    Match predictions and targets (ground truth) with respect to IoU and confidence score for a given image.

    :param preds: Tensor of shape (K, NumJoints, 3) - Array of predicted skeletons.
        The last dimension encodes the X, Y and confidence score of each joint
    :param pred_scores: Tensor of shape (K) - Confidence scores for each pose
    :param targets: Target joints (M, NumJoints, 2) - Array of ground truth skeletons
    :param targets_visibilities: Visibility status for each keypoint (M, NumJoints).
        Values are 0 - invisible, 1 - occluded, 2 - fully visible
    :param targets_areas: Tensor of shape (M) - Areas of target objects
    :param targets_bboxes: Tensor of shape (M, 4) - Bounding boxes (XYWH) of targets
    :param targets_ignored: Tensor of shape (M) - Targets that are marked as ignored
        (e.g. all keypoints are invisible or the target does not fit the desired area range)
    :param crowd_targets: Crowd target joints (Mc, NumJoints, 3) - Array of ground truth skeletons.
        The last dimension encodes the X, Y and visibility score of each joint:
        (0 - invisible, 1 - occluded, 2 - fully visible)
    :param crowd_visibilities: Visibility status for each keypoint of crowd targets (Mc, NumJoints).
        Values are 0 - invisible, 1 - occluded, 2 - fully visible
    :param crowd_targets_areas: Tensor of shape (Mc) - Areas of crowd target objects
    :param crowd_targets_bboxes: Tensor of shape (Mc, 4) - Bounding boxes (XYWH) of crowd targets
    :param iou_thresholds: IoU thresholds to compute the mAP
    :param sigmas: Tensor of shape (NumJoints) with a sigma for each joint. The sigma value represents
        how 'hard' it is to locate the exact ground truth position of the joint.
    :param top_k: Number of predictions to keep, ordered by confidence score
    :return:
        :preds_matched: Tensor of shape (min(top_k, len(preds)), n_iou_thresholds)
            True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
        :preds_to_ignore: Tensor of shape (min(top_k, len(preds)), n_iou_thresholds)
            True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
        :preds_scores: Tensor of shape (min(top_k, len(preds))) with the scores of the top-k predictions
        :num_targets: Number of ground truth targets (total number of targets minus the ignored ones)
    """
    num_iou_thresholds = len(iou_thresholds)
    device = preds.device if torch.is_tensor(preds) else (targets.device if torch.is_tensor(targets) else "cpu")

    if preds is None or len(preds) == 0:
        preds_matched = torch.zeros((0, num_iou_thresholds), dtype=torch.bool, device=device)
        preds_to_ignore = torch.zeros((0, num_iou_thresholds), dtype=torch.bool, device=device)
        preds_scores = torch.zeros((0,), dtype=torch.float, device=device)
        # Consistent with the non-empty branch: ignored targets are not counted
        return preds_matched, preds_to_ignore, preds_scores, len(targets) - int(torch.count_nonzero(targets_ignored))

    preds_matched = torch.zeros(len(preds), num_iou_thresholds, dtype=torch.bool, device=device)
    targets_matched = torch.zeros(len(targets), num_iou_thresholds, dtype=torch.bool, device=device)
    preds_to_ignore = torch.zeros(len(preds), num_iou_thresholds, dtype=torch.bool, device=device)

    # Ignore all but the top_k predictions
    k = min(top_k, len(pred_scores))
    preds_idx_to_use = torch.topk(pred_scores, k=k, sorted=True, largest=True).indices
    preds_to_ignore[:, :] = True
    preds_to_ignore[preds_idx_to_use] = False

    if len(targets) > 0:
        iou = compute_oks(preds[preds_idx_to_use], targets, targets_visibilities, sigmas, gt_areas=targets_areas, gt_bboxes=targets_bboxes)

        # The matching priority is first the detection confidence and then the IoU value.
        # The detections are already sorted by confidence in NMS, so here, for each prediction, the targets are ordered by IoU.
        sorted_iou, target_sorted = iou.sort(descending=True, stable=True)

        # Only iterate over IoU values higher than the minimum threshold to speed up the process
        for pred_selected_i, target_sorted_i in (sorted_iou > iou_thresholds[0]).nonzero(as_tuple=False):
            # pred_selected_i and target_sorted_i are relative to the filtering/sorting, so extract their absolute indexes
            pred_i = preds_idx_to_use[pred_selected_i]
            target_i = target_sorted[pred_selected_i, target_sorted_i]

            # Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold
            is_iou_above_threshold = sorted_iou[pred_selected_i, target_sorted_i] > iou_thresholds

            # Vector[j], True when neither pred_i nor target_i is matched yet for the (j)th threshold
            are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :])

            # Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold
            are_candidates_good = torch.logical_and(is_iou_above_threshold, are_candidates_free)

            is_matching_with_ignore = are_candidates_free & are_candidates_good & targets_ignored[target_i]

            if preds_matched[pred_i].any() and is_matching_with_ignore.any():
                continue

            # For every threshold (j) where target_i and pred_i can be matched together (are_candidates_good[j] == True),
            # fill the matching placeholders with True
            targets_matched[target_i, are_candidates_good] = True
            preds_matched[pred_i, are_candidates_good] = True

            preds_to_ignore[pred_i] = torch.logical_or(preds_to_ignore[pred_i], is_matching_with_ignore)

            # When all targets are matched with a prediction for every IoU threshold, stop.
            if targets_matched.all():
                break

    # Crowd targets can be matched with many predictions.
    # Therefore, for every prediction we just need to check if it has a high enough IoA with any crowd target.
    if len(crowd_targets) > 0:
        # shape = (n_preds_to_use x n_crowd_targets)
        ioa = compute_oks(
            preds[preds_idx_to_use],
            crowd_targets,
            crowd_visibilities,
            sigmas,
            gt_areas=crowd_targets_areas,
            gt_bboxes=crowd_targets_bboxes,
        )

        # For each prediction, keep its highest score with any crowd target (of the same class)
        # shape = (n_preds_to_use)
        best_ioa, _ = ioa.max(1)

        # If a prediction has an IoA higher than the threshold (with any target of the same class), then there is a match
        # shape = (n_preds_to_use x iou_thresholds)
        is_matching_with_crowd = best_ioa.view(-1, 1) > iou_thresholds.view(1, -1)
        preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd)

    num_targets = len(targets) - torch.count_nonzero(targets_ignored)
    return preds_matched[preds_idx_to_use], preds_to_ignore[preds_idx_to_use], pred_scores[preds_idx_to_use], num_targets.item()
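
For reviewers who want to exercise this file end to end, here is a minimal usage sketch with random inputs. The shapes and values (17 COCO joints, the standard COCO keypoint sigmas, IoU thresholds 0.5:0.05:0.95, top_k=20, no crowd annotations) are illustrative assumptions, not part of this PR.

import torch

num_joints = 17
K, M = 5, 3  # number of predictions and ground-truth instances

preds = torch.rand(K, num_joints, 3) * 640  # X, Y, per-joint confidence
pred_scores = torch.rand(K)                 # per-pose confidence
targets = torch.rand(M, num_joints, 2) * 640
targets_visibilities = torch.randint(0, 3, (M, num_joints))
targets_bboxes = compute_visible_bbox_xywh(targets, targets_visibilities)
targets_areas = targets_bboxes[:, 2] * targets_bboxes[:, 3]
targets_ignored = torch.zeros(M, dtype=torch.bool)

# No crowd annotations in this toy example
crowd_targets = torch.zeros(0, num_joints, 3)
crowd_visibilities = torch.zeros(0, num_joints)
crowd_targets_areas = torch.zeros(0)
crowd_targets_bboxes = torch.zeros(0, 4)

sigmas = torch.tensor(
    [0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
     0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089]
)  # standard COCO keypoint sigmas
iou_thresholds = torch.linspace(0.5, 0.95, 10)

preds_matched, preds_to_ignore, preds_scores, num_targets = compute_img_keypoint_matching(
    preds, pred_scores, targets, targets_visibilities, targets_areas, targets_bboxes,
    targets_ignored, crowd_targets, crowd_visibilities, crowd_targets_areas,
    crowd_targets_bboxes, iou_thresholds, sigmas, top_k=20,
)
print(preds_matched.shape, preds_to_ignore.shape, preds_scores.shape, num_targets)

The three returned tensors all share the first dimension min(top_k, K), so they can be fed directly into a mAP accumulator that aggregates matches per IoU threshold.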