Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commits into Ultralytics:main from ultralytics:yoloe-vp-fix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
  1. # Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license
  2. import torch
  3. import torch.nn as nn
  4. from . import LOGGER
  5. from .checks import check_version
  6. from .metrics import bbox_iou, probiou
  7. from .ops import xywhr2xyxyxyxy
  8. TORCH_1_10 = check_version(torch.__version__, "1.10.0")
  9. class TaskAlignedAssigner(nn.Module):
  10. """
  11. A task-aligned assigner for object detection.
  12. This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
  13. classification and localization information.
  14. Attributes:
  15. topk (int): The number of top candidates to consider.
  16. num_classes (int): The number of object classes.
  17. bg_idx (int): Background class index.
  18. alpha (float): The alpha parameter for the classification component of the task-aligned metric.
  19. beta (float): The beta parameter for the localization component of the task-aligned metric.
  20. eps (float): A small value to prevent division by zero.
  21. """
  22. def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
  23. """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
  24. super().__init__()
  25. self.topk = topk
  26. self.num_classes = num_classes
  27. self.bg_idx = num_classes
  28. self.alpha = alpha
  29. self.beta = beta
  30. self.eps = eps
  31. @torch.no_grad()
  32. def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
  33. """
  34. Compute the task-aligned assignment.
  35. Args:
  36. pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
  37. pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
  38. anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
  39. gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
  40. gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
  41. mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
  42. Returns:
  43. target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
  44. target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
  45. target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
  46. fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
  47. target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
  48. References:
  49. https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
  50. """
  51. self.bs = pd_scores.shape[0]
  52. self.n_max_boxes = gt_bboxes.shape[1]
  53. device = gt_bboxes.device
  54. if self.n_max_boxes == 0:
  55. return (
  56. torch.full_like(pd_scores[..., 0], self.bg_idx),
  57. torch.zeros_like(pd_bboxes),
  58. torch.zeros_like(pd_scores),
  59. torch.zeros_like(pd_scores[..., 0]),
  60. torch.zeros_like(pd_scores[..., 0]),
  61. )
  62. try:
  63. return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
  64. except torch.cuda.OutOfMemoryError:
  65. # Move tensors to CPU, compute, then move back to original device
  66. LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
  67. cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
  68. result = self._forward(*cpu_tensors)
  69. return tuple(t.to(device) for t in result)
  70. def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
  71. """
  72. Compute the task-aligned assignment.
  73. Args:
  74. pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
  75. pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
  76. anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
  77. gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
  78. gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
  79. mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
  80. Returns:
  81. target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
  82. target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
  83. target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
  84. fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
  85. target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
  86. """
  87. mask_pos, align_metric, overlaps = self.get_pos_mask(
  88. pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
  89. )
  90. target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
  91. # Assigned target
  92. target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
  93. # Normalize
  94. align_metric *= mask_pos
  95. pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj
  96. pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj
  97. norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
  98. target_scores = target_scores * norm_align_metric
  99. return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
  100. def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
  101. """
  102. Get positive mask for each ground truth box.
  103. Args:
  104. pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
  105. pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
  106. gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
  107. gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
  108. anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
  109. mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
  110. Returns:
  111. mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
  112. align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
  113. overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
  114. """
  115. mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
  116. # Get anchor_align metric, (b, max_num_obj, h*w)
  117. align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
  118. # Get topk_metric mask, (b, max_num_obj, h*w)
  119. mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
  120. # Merge all mask to a final mask, (b, max_num_obj, h*w)
  121. mask_pos = mask_topk * mask_in_gts * mask_gt
  122. return mask_pos, align_metric, overlaps
  123. def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
  124. """
  125. Compute alignment metric given predicted and ground truth bounding boxes.
  126. Args:
  127. pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
  128. pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
  129. gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
  130. gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
  131. mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).
  132. Returns:
  133. align_metric (torch.Tensor): Alignment metric combining classification and localization.
  134. overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
  135. """
  136. na = pd_bboxes.shape[-2]
  137. mask_gt = mask_gt.bool() # b, max_num_obj, h*w
  138. overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
  139. bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
  140. ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
  141. ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj
  142. ind[1] = gt_labels.squeeze(-1) # b, max_num_obj
  143. # Get the scores of each grid for each gt cls
  144. bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w
  145. # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
  146. pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
  147. gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
  148. overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)
  149. align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
  150. return align_metric, overlaps
  151. def iou_calculation(self, gt_bboxes, pd_bboxes):
  152. """
  153. Calculate IoU for horizontal bounding boxes.
  154. Args:
  155. gt_bboxes (torch.Tensor): Ground truth boxes.
  156. pd_bboxes (torch.Tensor): Predicted boxes.
  157. Returns:
  158. (torch.Tensor): IoU values between each pair of boxes.
  159. """
  160. return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
  161. def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
  162. """
  163. Select the top-k candidates based on the given metrics.
  164. Args:
  165. metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
  166. max_num_obj is the maximum number of objects, and h*w represents the
  167. total number of anchor points.
  168. largest (bool): If True, select the largest values; otherwise, select the smallest values.
  169. topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
  170. topk is the number of top candidates to consider. If not provided,
  171. the top-k values are automatically computed based on the given metrics.
  172. Returns:
  173. (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
  174. """
  175. # (b, max_num_obj, topk)
  176. topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
  177. if topk_mask is None:
  178. topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
  179. # (b, max_num_obj, topk)
  180. topk_idxs.masked_fill_(~topk_mask, 0)
  181. # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
  182. count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
  183. ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
  184. for k in range(self.topk):
  185. # Expand topk_idxs for each value of k and add 1 at the specified positions
  186. count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
  187. # Filter invalid bboxes
  188. count_tensor.masked_fill_(count_tensor > 1, 0)
  189. return count_tensor.to(metrics.dtype)
  190. def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
  191. """
  192. Compute target labels, target bounding boxes, and target scores for the positive anchor points.
  193. Args:
  194. gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
  195. batch size and max_num_obj is the maximum number of objects.
  196. gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
  197. target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
  198. anchor points, with shape (b, h*w), where h*w is the total
  199. number of anchor points.
  200. fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
  201. (foreground) anchor points.
  202. Returns:
  203. target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
  204. target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
  205. anchor points.
  206. target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
  207. anchor points.
  208. """
  209. # Assigned target labels, (b, 1)
  210. batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
  211. target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
  212. target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
  213. # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
  214. target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
  215. # Assigned target scores
  216. target_labels.clamp_(0)
  217. # 10x faster than F.one_hot()
  218. target_scores = torch.zeros(
  219. (target_labels.shape[0], target_labels.shape[1], self.num_classes),
  220. dtype=torch.int64,
  221. device=target_labels.device,
  222. ) # (b, h*w, 80)
  223. target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
  224. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
  225. target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
  226. return target_labels, target_bboxes, target_scores
  227. @staticmethod
  228. def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
  229. """
  230. Select positive anchor centers within ground truth bounding boxes.
  231. Args:
  232. xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
  233. gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
  234. eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
  235. Returns:
  236. (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
  237. Note:
  238. b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
  239. Bounding box format: [x_min, y_min, x_max, y_max].
  240. """
  241. n_anchors = xy_centers.shape[0]
  242. bs, n_boxes, _ = gt_bboxes.shape
  243. lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
  244. bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
  245. return bbox_deltas.amin(3).gt_(eps)
  246. @staticmethod
  247. def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
  248. """
  249. Select anchor boxes with highest IoU when assigned to multiple ground truths.
  250. Args:
  251. mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
  252. overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
  253. n_max_boxes (int): Maximum number of ground truth boxes.
  254. Returns:
  255. target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
  256. fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
  257. mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
  258. """
  259. # Convert (b, n_max_boxes, h*w) -> (b, h*w)
  260. fg_mask = mask_pos.sum(-2)
  261. if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
  262. mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
  263. max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
  264. is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
  265. is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
  266. mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
  267. fg_mask = mask_pos.sum(-2)
  268. # Find each grid serve which gt(index)
  269. target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
  270. return target_gt_idx, fg_mask, mask_pos
  271. class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
  272. """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
  273. def iou_calculation(self, gt_bboxes, pd_bboxes):
  274. """Calculate IoU for rotated bounding boxes."""
  275. return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
  276. @staticmethod
  277. def select_candidates_in_gts(xy_centers, gt_bboxes):
  278. """
  279. Select the positive anchor center in gt for rotated bounding boxes.
  280. Args:
  281. xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
  282. gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
  283. Returns:
  284. (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
  285. """
  286. # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
  287. corners = xywhr2xyxyxyxy(gt_bboxes)
  288. # (b, n_boxes, 1, 2)
  289. a, b, _, d = corners.split(1, dim=-2)
  290. ab = b - a
  291. ad = d - a
  292. # (b, n_boxes, h*w, 2)
  293. ap = xy_centers - a
  294. norm_ab = (ab * ab).sum(dim=-1)
  295. norm_ad = (ad * ad).sum(dim=-1)
  296. ap_dot_ab = (ap * ab).sum(dim=-1)
  297. ap_dot_ad = (ap * ad).sum(dim=-1)
  298. return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box
  299. def make_anchors(feats, strides, grid_cell_offset=0.5):
  300. """Generate anchors from features."""
  301. anchor_points, stride_tensor = [], []
  302. assert feats is not None
  303. dtype, device = feats[0].dtype, feats[0].device
  304. for i, stride in enumerate(strides):
  305. h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
  306. sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
  307. sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
  308. sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
  309. anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
  310. stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
  311. return torch.cat(anchor_points), torch.cat(stride_tensor)
  312. def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
  313. """Transform distance(ltrb) to box(xywh or xyxy)."""
  314. lt, rb = distance.chunk(2, dim)
  315. x1y1 = anchor_points - lt
  316. x2y2 = anchor_points + rb
  317. if xywh:
  318. c_xy = (x1y1 + x2y2) / 2
  319. wh = x2y2 - x1y1
  320. return torch.cat((c_xy, wh), dim) # xywh bbox
  321. return torch.cat((x1y1, x2y2), dim) # xyxy bbox
  322. def bbox2dist(anchor_points, bbox, reg_max):
  323. """Transform bbox(xyxy) to dist(ltrb)."""
  324. x1y1, x2y2 = bbox.chunk(2, -1)
  325. return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)
  326. def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
  327. """
  328. Decode predicted rotated bounding box coordinates from anchor points and distribution.
  329. Args:
  330. pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
  331. pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
  332. anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
  333. dim (int, optional): Dimension along which to split. Defaults to -1.
  334. Returns:
  335. (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
  336. """
  337. lt, rb = pred_dist.split(2, dim=dim)
  338. cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
  339. # (bs, h*w, 1)
  340. xf, yf = ((rb - lt) / 2).split(1, dim=dim)
  341. x, y = xf * cos - yf * sin, xf * sin + yf * cos
  342. xy = torch.cat([x, y], dim=dim) + anchor_points
  343. return torch.cat([xy, lt + rb], dim=dim)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file