Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_transforms_v2_utils.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
  1. import PIL.Image
  2. import pytest
  3. import torch
  4. import torchvision.transforms.v2._utils
  5. from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image
  6. from torchvision import tv_tensors
  7. from torchvision.transforms.v2._utils import has_all, has_any
  8. from torchvision.transforms.v2.functional import to_pil_image
# Shared sample entries for the parametrized has_any / has_all tests below.
# They are built once, at module import time, by the project-local
# common_utils helpers. NOTE(review): the contents those helpers produce
# (e.g. whether they are random) are not visible here.
# An RGB image of DEFAULT_SIZE.
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
# Bounding boxes of DEFAULT_SIZE in XYXY format.
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
# Detection masks of DEFAULT_SIZE.
MASK = make_detection_masks(DEFAULT_SIZE)
  12. @pytest.mark.parametrize(
  13. ("sample", "types", "expected"),
  14. [
  15. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
  16. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
  17. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
  18. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
  19. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
  20. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
  21. ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
  22. ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
  23. ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
  24. (
  25. (IMAGE, BOUNDING_BOX, MASK),
  26. (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
  27. True,
  28. ),
  29. ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
  30. ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
  31. ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
  32. ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
  33. ((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
  34. (
  35. (torch.Tensor(IMAGE),),
  36. (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
  37. True,
  38. ),
  39. (
  40. (to_pil_image(IMAGE),),
  41. (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
  42. True,
  43. ),
  44. ],
  45. )
  46. def test_has_any(sample, types, expected):
  47. assert has_any(sample, *types) is expected
  48. @pytest.mark.parametrize(
  49. ("sample", "types", "expected"),
  50. [
  51. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
  52. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
  53. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
  54. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
  55. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
  56. ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
  57. (
  58. (IMAGE, BOUNDING_BOX, MASK),
  59. (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
  60. True,
  61. ),
  62. ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
  63. ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
  64. ((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
  65. (
  66. (IMAGE, BOUNDING_BOX, MASK),
  67. (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
  68. True,
  69. ),
  70. ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
  71. ((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
  72. ((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
  73. (
  74. (IMAGE, BOUNDING_BOX, MASK),
  75. (lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
  76. True,
  77. ),
  78. ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
  79. ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
  80. ],
  81. )
  82. def test_has_all(sample, types, expected):
  83. assert has_all(sample, *types) is expected
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...