Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_io.py 12 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
  1. import contextlib
  2. import os
  3. import sys
  4. import tempfile
  5. import pytest
  6. import torch
  7. import torchvision.io as io
  8. from common_utils import assert_equal, cpu_and_cuda
  9. from torchvision import get_video_backend
# PyAV is optional. If either the import or the availability check below
# raises ImportError, `av` is set to None so the tests can be skipped.
try:
    import av

    # Do a version test too
    io.video._check_av_available()
except ImportError:
    av = None

# Directory ("assets/videos") located next to this file; joined with asset
# file names by the tests below.
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
  17. def _create_video_frames(num_frames, height, width):
  18. y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width), indexing="ij")
  19. data = []
  20. for i in range(num_frames):
  21. xc = float(i) / num_frames
  22. yc = 1 - float(i) / (2 * num_frames)
  23. d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
  24. data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())
  25. return torch.stack(data, 0)
  26. @contextlib.contextmanager
  27. def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
  28. if lossless:
  29. if video_codec is not None:
  30. raise ValueError("video_codec can't be specified together with lossless")
  31. if options is not None:
  32. raise ValueError("options can't be specified together with lossless")
  33. video_codec = "libx264rgb"
  34. options = {"crf": "0"}
  35. if video_codec is None:
  36. if get_video_backend() == "pyav":
  37. video_codec = "libx264"
  38. else:
  39. # when video_codec is not set, we assume it is libx264rgb which accepts
  40. # RGB pixel formats as input instead of YUV
  41. video_codec = "libx264rgb"
  42. if options is None:
  43. options = {}
  44. data = _create_video_frames(num_frames, height, width)
  45. with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
  46. f.close()
  47. io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
  48. yield f.name, data
  49. os.unlink(f.name)
  50. @pytest.mark.skipif(
  51. get_video_backend() != "pyav" and not io._HAS_CPU_VIDEO_DECODER, reason="video_reader backend not available"
  52. )
  53. @pytest.mark.skipif(av is None, reason="PyAV unavailable")
  54. class TestVideo:
  55. # compression adds artifacts, thus we add a tolerance of
  56. # 6 in 0-255 range
  57. TOLERANCE = 6
  58. def test_write_read_video(self):
  59. with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
  60. lv, _, info = io.read_video(f_name)
  61. assert_equal(data, lv)
  62. assert info["video_fps"] == 5
  63. @pytest.mark.skipif(not io._HAS_CPU_VIDEO_DECODER, reason="video_reader backend is not chosen")
  64. def test_probe_video_from_file(self):
  65. with temp_video(10, 300, 300, 5) as (f_name, data):
  66. video_info = io._probe_video_from_file(f_name)
  67. assert pytest.approx(2, rel=0.0, abs=0.1) == video_info.video_duration
  68. assert pytest.approx(5, rel=0.0, abs=0.1) == video_info.video_fps
  69. @pytest.mark.skipif(not io._HAS_CPU_VIDEO_DECODER, reason="video_reader backend is not chosen")
  70. def test_probe_video_from_memory(self):
  71. with temp_video(10, 300, 300, 5) as (f_name, data):
  72. with open(f_name, "rb") as fp:
  73. filebuffer = fp.read()
  74. video_info = io._probe_video_from_memory(filebuffer)
  75. assert pytest.approx(2, rel=0.0, abs=0.1) == video_info.video_duration
  76. assert pytest.approx(5, rel=0.0, abs=0.1) == video_info.video_fps
  77. def test_read_timestamps(self):
  78. with temp_video(10, 300, 300, 5) as (f_name, data):
  79. pts, _ = io.read_video_timestamps(f_name)
  80. # note: not all formats/codecs provide accurate information for computing the
  81. # timestamps. For the format that we use here, this information is available,
  82. # so we use it as a baseline
  83. with av.open(f_name) as container:
  84. stream = container.streams[0]
  85. pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
  86. num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
  87. expected_pts = [i * pts_step for i in range(num_frames)]
  88. assert pts == expected_pts
  89. @pytest.mark.parametrize("start", range(5))
  90. @pytest.mark.parametrize("offset", range(1, 4))
  91. def test_read_partial_video(self, start, offset):
  92. with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
  93. pts, _ = io.read_video_timestamps(f_name)
  94. lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
  95. s_data = data[start : (start + offset)]
  96. assert len(lv) == offset
  97. assert_equal(s_data, lv)
  98. if get_video_backend() == "pyav":
  99. # for "video_reader" backend, we don't decode the closest early frame
  100. # when the given start pts is not matching any frame pts
  101. lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
  102. assert len(lv) == 4
  103. assert_equal(data[4:8], lv)
  104. @pytest.mark.parametrize("start", range(0, 80, 20))
  105. @pytest.mark.parametrize("offset", range(1, 4))
  106. def test_read_partial_video_bframes(self, start, offset):
  107. # do not use lossless encoding, to test the presence of B-frames
  108. options = {"bframes": "16", "keyint": "10", "min-keyint": "4"}
  109. with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
  110. pts, _ = io.read_video_timestamps(f_name)
  111. lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
  112. s_data = data[start : (start + offset)]
  113. assert len(lv) == offset
  114. assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE)
  115. lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
  116. # TODO fix this
  117. if get_video_backend() == "pyav":
  118. assert len(lv) == 4
  119. assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE)
  120. else:
  121. assert len(lv) == 3
  122. assert_equal(data[5:8], lv, rtol=0.0, atol=self.TOLERANCE)
  123. def test_read_packed_b_frames_divx_file(self):
  124. name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
  125. f_name = os.path.join(VIDEO_DIR, name)
  126. pts, fps = io.read_video_timestamps(f_name)
  127. assert pts == sorted(pts)
  128. assert fps == 30
  129. def test_read_timestamps_from_packet(self):
  130. with temp_video(10, 300, 300, 5, video_codec="mpeg4") as (f_name, data):
  131. pts, _ = io.read_video_timestamps(f_name)
  132. # note: not all formats/codecs provide accurate information for computing the
  133. # timestamps. For the format that we use here, this information is available,
  134. # so we use it as a baseline
  135. with av.open(f_name) as container:
  136. stream = container.streams[0]
  137. # make sure we went through the optimized codepath
  138. assert b"Lavc" in stream.codec_context.extradata
  139. pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
  140. num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
  141. expected_pts = [i * pts_step for i in range(num_frames)]
  142. assert pts == expected_pts
  143. def test_read_video_pts_unit_sec(self):
  144. with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
  145. lv, _, info = io.read_video(f_name, pts_unit="sec")
  146. assert_equal(data, lv)
  147. assert info["video_fps"] == 5
  148. assert info == {"video_fps": 5}
  149. def test_read_timestamps_pts_unit_sec(self):
  150. with temp_video(10, 300, 300, 5) as (f_name, data):
  151. pts, _ = io.read_video_timestamps(f_name, pts_unit="sec")
  152. with av.open(f_name) as container:
  153. stream = container.streams[0]
  154. pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
  155. num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
  156. expected_pts = [i * pts_step * stream.time_base for i in range(num_frames)]
  157. assert pts == expected_pts
  158. @pytest.mark.parametrize("start", range(5))
  159. @pytest.mark.parametrize("offset", range(1, 4))
  160. def test_read_partial_video_pts_unit_sec(self, start, offset):
  161. with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
  162. pts, _ = io.read_video_timestamps(f_name, pts_unit="sec")
  163. lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit="sec")
  164. s_data = data[start : (start + offset)]
  165. assert len(lv) == offset
  166. assert_equal(s_data, lv)
  167. with av.open(f_name) as container:
  168. stream = container.streams[0]
  169. lv, _, _ = io.read_video(
  170. f_name, int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], pts_unit="sec"
  171. )
  172. if get_video_backend() == "pyav":
  173. # for "video_reader" backend, we don't decode the closest early frame
  174. # when the given start pts is not matching any frame pts
  175. assert len(lv) == 4
  176. assert_equal(data[4:8], lv)
  177. def test_read_video_corrupted_file(self):
  178. with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
  179. f.write(b"This is not an mpg4 file")
  180. video, audio, info = io.read_video(f.name)
  181. assert isinstance(video, torch.Tensor)
  182. assert isinstance(audio, torch.Tensor)
  183. assert video.numel() == 0
  184. assert audio.numel() == 0
  185. assert info == {}
  186. def test_read_video_timestamps_corrupted_file(self):
  187. with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
  188. f.write(b"This is not an mpg4 file")
  189. video_pts, video_fps = io.read_video_timestamps(f.name)
  190. assert video_pts == []
  191. assert video_fps is None
  192. @pytest.mark.skip(reason="Temporarily disabled due to new pyav")
  193. def test_read_video_partially_corrupted_file(self):
  194. with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data):
  195. with open(f_name, "r+b") as f:
  196. size = os.path.getsize(f_name)
  197. bytes_to_overwrite = size // 10
  198. # seek to the middle of the file
  199. f.seek(5 * bytes_to_overwrite)
  200. # corrupt 10% of the file from the middle
  201. f.write(b"\xff" * bytes_to_overwrite)
  202. # this exercises the container.decode assertion check
  203. video, audio, info = io.read_video(f.name, pts_unit="sec")
  204. # check that size is not equal to 5, but 3
  205. # TODO fix this
  206. if get_video_backend() == "pyav":
  207. assert len(video) == 3
  208. else:
  209. assert len(video) == 4
  210. # but the valid decoded content is still correct
  211. assert_equal(video[:3], data[:3])
  212. # and the last few frames are wrong
  213. with pytest.raises(AssertionError):
  214. assert_equal(video, data)
  215. @pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows")
  216. @pytest.mark.parametrize("device", cpu_and_cuda())
  217. def test_write_video_with_audio(self, device, tmpdir):
  218. f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
  219. video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")
  220. out_f_name = os.path.join(tmpdir, "testing.mp4")
  221. io.video.write_video(
  222. out_f_name,
  223. video_tensor.to(device),
  224. round(info["video_fps"]),
  225. video_codec="libx264rgb",
  226. options={"crf": "0"},
  227. audio_array=audio_tensor.to(device),
  228. audio_fps=info["audio_fps"],
  229. audio_codec="aac",
  230. )
  231. out_video_tensor, out_audio_tensor, out_info = io.read_video(out_f_name, pts_unit="sec")
  232. assert info["video_fps"] == out_info["video_fps"]
  233. assert_equal(video_tensor, out_video_tensor)
  234. audio_stream = av.open(f_name).streams.audio[0]
  235. out_audio_stream = av.open(out_f_name).streams.audio[0]
  236. assert info["audio_fps"] == out_info["audio_fps"]
  237. assert audio_stream.rate == out_audio_stream.rate
  238. assert pytest.approx(out_audio_stream.frames, rel=0.0, abs=1) == audio_stream.frames
  239. assert audio_stream.frame_size == out_audio_stream.frame_size
  240. # TODO add tests for audio
  241. if __name__ == "__main__":
  242. pytest.main(__file__)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...