Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

smoke_test.py 5.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
  1. """Run smoke tests"""
  2. import os
  3. import sys
  4. import sysconfig
  5. from pathlib import Path
  6. import torch
  7. import torchvision
  8. from torchvision.io import decode_avif, decode_heic, decode_image, decode_jpeg, read_file
  9. from torchvision.models import resnet50, ResNet50_Weights
  10. SCRIPT_DIR = Path(__file__).parent
  11. def smoke_test_torchvision() -> None:
  12. print(
  13. "Is torchvision usable?",
  14. all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
  15. )
  16. def smoke_test_torchvision_read_decode() -> None:
  17. img_jpg = decode_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
  18. if img_jpg.shape != (3, 606, 517):
  19. raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
  20. img_png = decode_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
  21. if img_png.shape != (4, 471, 354):
  22. raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
  23. img_webp = decode_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
  24. if img_webp.shape != (3, 100, 100):
  25. raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}")
  26. if sys.platform == "linux":
  27. pass
  28. # TODO: Fix/uncomment below (the TODO below is mostly accurate but we're
  29. # still observing some failures on some CUDA jobs. Most are working.)
  30. # if torch.cuda.is_available():
  31. # # TODO: For whatever reason this only passes on the runners that
  32. # # support CUDA.
  33. # # Strangely, on the CPU runners where this fails, the AVIF/HEIC
  34. # # tests (ran with pytest) are passing. This is likely related to a
  35. # # libcxx symbol thing, and the proper libstdc++.so get loaded only
  36. # # with pytest? Ugh.
  37. # img_avif = decode_avif(read_file(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.avif")))
  38. # if img_avif.shape != (3, 100, 100):
  39. # raise RuntimeError(f"Unexpected shape of img_avif: {img_avif.shape}")
  40. # img_heic = decode_heic(
  41. # read_file(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic"))
  42. # )
  43. # if img_heic.shape != (3, 100, 100):
  44. # raise RuntimeError(f"Unexpected shape of img_heic: {img_heic.shape}")
  45. else:
  46. try:
  47. decode_avif(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.avif"))
  48. except RuntimeError as e:
  49. assert "torchvision-extra-decoders" in str(e)
  50. try:
  51. decode_heic(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic"))
  52. except RuntimeError as e:
  53. assert "torchvision-extra-decoders" in str(e)
  54. def smoke_test_torchvision_decode_jpeg(device: str = "cpu"):
  55. img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
  56. img_jpg = decode_jpeg(img_jpg_data, device=device)
  57. if img_jpg.shape != (3, 606, 517):
  58. raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
  59. def smoke_test_compile() -> None:
  60. try:
  61. model = resnet50().cuda()
  62. model = torch.compile(model)
  63. x = torch.randn(1, 3, 224, 224, device="cuda")
  64. out = model(x)
  65. print(f"torch.compile model output: {out.shape}")
  66. except RuntimeError:
  67. if sys.platform == "win32":
  68. print("Successfully caught torch.compile RuntimeError on win")
  69. else:
  70. raise
  71. def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
  72. img = decode_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
  73. # Step 1: Initialize model with the best available weights
  74. weights = ResNet50_Weights.DEFAULT
  75. model = resnet50(weights=weights, progress=False).to(device)
  76. model.eval()
  77. # Step 2: Initialize the inference transforms
  78. preprocess = weights.transforms(antialias=True)
  79. # Step 3: Apply inference preprocessing transforms
  80. batch = preprocess(img).unsqueeze(0)
  81. # Step 4: Use the model and print the predicted category
  82. prediction = model(batch).squeeze(0).softmax(0)
  83. class_id = prediction.argmax().item()
  84. score = prediction[class_id].item()
  85. category_name = weights.meta["categories"][class_id]
  86. expected_category = "German shepherd"
  87. print(f"{category_name} ({device}): {100 * score:.1f}%")
  88. if category_name != expected_category:
  89. raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
  90. def main() -> None:
  91. print(f"torchvision: {torchvision.__version__}")
  92. print(f"torch.cuda.is_available: {torch.cuda.is_available()}")
  93. print(f"{torch.ops.image._jpeg_version() = }")
  94. if not torch.ops.image._is_compiled_against_turbo():
  95. msg = "Torchvision wasn't compiled against libjpeg-turbo"
  96. if os.getenv("IS_M1_CONDA_BUILD_JOB") == "1":
  97. # When building the conda package on M1, it's difficult to enforce
  98. # that we build against turbo due to interactions with the libwebp
  99. # package. So we just accept it, instead of raising an error.
  100. print(msg)
  101. else:
  102. raise ValueError(msg)
  103. smoke_test_torchvision()
  104. smoke_test_torchvision_read_decode()
  105. smoke_test_torchvision_resnet50_classify()
  106. smoke_test_torchvision_decode_jpeg()
  107. if torch.cuda.is_available():
  108. smoke_test_torchvision_decode_jpeg("cuda")
  109. smoke_test_torchvision_resnet50_classify("cuda")
  110. # torch.compile is not supported on Python 3.14+ and Python built with GIL disabled
  111. if sys.version_info < (3, 14, 0) and not sysconfig.get_config_var("Py_GIL_DISABLED"):
  112. smoke_test_compile()
  113. if torch.backends.mps.is_available():
  114. smoke_test_torchvision_resnet50_classify("mps")
  115. if __name__ == "__main__":
  116. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...