Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#581 Bug/sg 512 shuffle bugfix in recipe datalaoders

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bug/SG-512_shuffle_bugfix_in_recipe_datalaoders
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
  1. import warnings
  2. from typing import Tuple
  3. import numpy as np
  4. import torch
  5. from super_gradients.training.utils.bbox_formats.bbox_format import (
  6. BoundingBoxFormat,
  7. )
  8. __all__ = ["xyxy_to_cxcywh", "xyxy_to_cxcywh_inplace", "cxcywh_to_xyxy_inplace", "cxcywh_to_xyxy", "CXCYWHCoordinateFormat"]
  9. def xyxy_to_cxcywh(bboxes, image_shape: Tuple[int, int]):
  10. """
  11. Transforms bboxes from xyxy format to CX-CY-W-H format
  12. :param bboxes: BBoxes of shape (..., 4) in XYXY format
  13. :return: BBoxes of shape (..., 4) in CX-CY-W-H format
  14. """
  15. x1, y1, x2, y2 = bboxes[..., 0], bboxes[..., 1], bboxes[..., 2], bboxes[..., 3]
  16. w = x2 - x1
  17. h = y2 - y1
  18. cx = x1 + 0.5 * w
  19. cy = y1 + 0.5 * h
  20. if torch.jit.is_scripting():
  21. return torch.stack([cx, cy, w, h], dim=-1)
  22. else:
  23. if torch.is_tensor(bboxes):
  24. return torch.stack([cx, cy, w, h], dim=-1)
  25. elif isinstance(bboxes, np.ndarray):
  26. return np.stack([cx, cy, w, h], axis=-1)
  27. else:
  28. raise RuntimeError(f"Only Torch tensor or Numpy array is supported. Received bboxes of type {str(type(bboxes))}")
  29. def cxcywh_to_xyxy(bboxes, image_shape: Tuple[int, int]):
  30. """
  31. Transforms bboxes from CX-CY-W-H format to XYXY format
  32. :param bboxes: BBoxes of shape (..., 4) in CX-CY-W-H format
  33. :return: BBoxes of shape (..., 4) in XYXY format
  34. """
  35. cx, cy, w, h = bboxes[..., 0], bboxes[..., 1], bboxes[..., 2], bboxes[..., 3]
  36. x1 = cx - 0.5 * w
  37. y1 = cy - 0.5 * h
  38. x2 = x1 + w
  39. y2 = y1 + h
  40. if torch.jit.is_scripting():
  41. return torch.stack([x1, y1, x2, y2], dim=-1)
  42. else:
  43. if torch.is_tensor(bboxes):
  44. return torch.stack([x1, y1, x2, y2], dim=-1)
  45. if isinstance(bboxes, np.ndarray):
  46. return np.stack([x1, y1, x2, y2], axis=-1)
  47. else:
  48. raise RuntimeError(f"Only Torch tensor or Numpy array is supported. Received bboxes of type {str(type(bboxes))}")
  49. def cxcywh_to_xyxy_inplace(bboxes, image_shape: Tuple[int, int]):
  50. """
  51. Not that bboxes dtype is preserved, and it may lead to unwanted rounding errors when computing a center of bbox.
  52. :param bboxes: BBoxes of shape (..., 4) in CX-CY-W-H format
  53. :return: BBoxes of shape (..., 4) in XYXY format
  54. """
  55. if not torch.jit.is_scripting():
  56. if torch.is_tensor(bboxes) and not torch.is_floating_point(bboxes):
  57. warnings.warn(
  58. f"Detected non floating-point ({bboxes.dtype}) input to cxcywh_to_xyxy_inplace function. "
  59. f"This may cause rounding errors and lose of precision. You may want to convert your array to floating-point precision first."
  60. )
  61. if isinstance(bboxes, np.ndarray) and not np.issubdtype(bboxes.dtype, np.floating):
  62. warnings.warn(
  63. f"Detected non floating-point input ({bboxes.dtype}) to cxcywh_to_xyxy_inplace function. "
  64. f"This may cause rounding errors and lose of precision. You may want to convert your array to floating-point precision first."
  65. )
  66. bboxes[..., 0:2] -= bboxes[..., 2:4] * 0.5 # cxcy -> x1y1
  67. bboxes[..., 2:4] += bboxes[..., 0:2] # x1y1 + wh -> x2y2
  68. return bboxes
  69. def xyxy_to_cxcywh_inplace(bboxes, image_shape: Tuple[int, int]):
  70. """
  71. Transforms bboxes from xyxy format to CX-CY-W-H format. This function operates in-place.
  72. Not that bboxes dtype is preserved, and it may lead to unwanted rounding errors when computing a center of bbox.
  73. :param bboxes: BBoxes of shape (..., 4) in XYXY format
  74. :return: BBoxes of shape (..., 4) in CX-CY-W-H format
  75. """
  76. if not torch.jit.is_scripting():
  77. if torch.is_tensor(bboxes) and not torch.is_floating_point(bboxes):
  78. warnings.warn(
  79. f"Detected non floating-point ({bboxes.dtype}) input to xyxy_to_cxcywh_inplace function. This may cause rounding errors and lose of precision. "
  80. "You may want to convert your array to floating-point precision first."
  81. )
  82. if isinstance(bboxes, np.ndarray) and not isinstance(bboxes.dtype, np.floating):
  83. warnings.warn(
  84. f"Detected non floating-point input ({bboxes.dtype}) to xyxy_to_cxcywh_inplace function. This may cause rounding errors and lose of precision. "
  85. "You may want to convert your array to floating-point precision first."
  86. )
  87. bboxes[..., 2:4] -= bboxes[..., 0:2] # x2y2 - x1y2 -> wh
  88. bboxes[..., 0:2] += bboxes[..., 2:4] * 0.5 # cxcywh
  89. return bboxes
  90. class CXCYWHCoordinateFormat(BoundingBoxFormat):
  91. def __init__(self):
  92. self.format = "cxcywh"
  93. self.normalized = False
  94. def get_to_xyxy(self, inplace: bool):
  95. if inplace:
  96. return cxcywh_to_xyxy_inplace
  97. else:
  98. return cxcywh_to_xyxy
  99. def get_from_xyxy(self, inplace: bool):
  100. if inplace:
  101. return xyxy_to_cxcywh_inplace
  102. else:
  103. return xyxy_to_cxcywh
Discard
Tip!

Press p or to see the previous file or, n or to see the next file