Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

video_utils.py 17 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
  1. import bisect
  2. import math
  3. import warnings
  4. from fractions import Fraction
  5. from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union
  6. import torch
  7. from torchvision.io import _probe_video_from_file, _read_video_from_file, read_video, read_video_timestamps
  8. from .utils import tqdm
# Generic type variable; `_collate_fn` uses it to express that it returns a value of the same type it receives.
T = TypeVar("T")
  10. def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
  11. """convert pts between different time bases
  12. Args:
  13. pts: presentation timestamp, float
  14. timebase_from: original timebase. Fraction
  15. timebase_to: new timebase. Fraction
  16. round_func: rounding function.
  17. """
  18. new_pts = Fraction(pts, 1) * timebase_from / timebase_to
  19. return round_func(new_pts)
  20. def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
  21. """
  22. similar to tensor.unfold, but with the dilation
  23. and specialized for 1d tensors
  24. Returns all consecutive windows of `size` elements, with
  25. `step` between windows. The distance between each element
  26. in a window is given by `dilation`.
  27. """
  28. if tensor.dim() != 1:
  29. raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
  30. o_stride = tensor.stride(0)
  31. numel = tensor.numel()
  32. new_stride = (step * o_stride, dilation * o_stride)
  33. new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
  34. if new_size[0] < 1:
  35. new_size = (0, size)
  36. return torch.as_strided(tensor, new_size, new_stride)
  37. class _VideoTimestampsDataset:
  38. """
  39. Dataset used to parallelize the reading of the timestamps
  40. of a list of videos, given their paths in the filesystem.
  41. Used in VideoClips and defined at top level, so it can be
  42. pickled when forking.
  43. """
  44. def __init__(self, video_paths: List[str]) -> None:
  45. self.video_paths = video_paths
  46. def __len__(self) -> int:
  47. return len(self.video_paths)
  48. def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
  49. return read_video_timestamps(self.video_paths[idx])
  50. def _collate_fn(x: T) -> T:
  51. """
  52. Dummy collate function to be used with _VideoTimestampsDataset
  53. """
  54. return x
  55. class VideoClips:
  56. """
  57. Given a list of video files, computes all consecutive subvideos of size
  58. `clip_length_in_frames`, where the distance between each subvideo in the
  59. same video is defined by `frames_between_clips`.
  60. If `frame_rate` is specified, it will also resample all the videos to have
  61. the same frame rate, and the clips will refer to this frame rate.
  62. Creating this instance the first time is time-consuming, as it needs to
  63. decode all the videos in `video_paths`. It is recommended that you
  64. cache the results after instantiation of the class.
  65. Recreating the clips for different clip lengths is fast, and can be done
  66. with the `compute_clips` method.
  67. Args:
  68. video_paths (List[str]): paths to the video files
  69. clip_length_in_frames (int): size of a clip in number of frames
  70. frames_between_clips (int): step (in frames) between each clip
  71. frame_rate (float, optional): if specified, it will resample the video
  72. so that it has `frame_rate`, and then the clips will be defined
  73. on the resampled video
  74. num_workers (int): how many subprocesses to use for data loading.
  75. 0 means that the data will be loaded in the main process. (default: 0)
  76. output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
  77. """
  78. def __init__(
  79. self,
  80. video_paths: List[str],
  81. clip_length_in_frames: int = 16,
  82. frames_between_clips: int = 1,
  83. frame_rate: Optional[float] = None,
  84. _precomputed_metadata: Optional[Dict[str, Any]] = None,
  85. num_workers: int = 0,
  86. _video_width: int = 0,
  87. _video_height: int = 0,
  88. _video_min_dimension: int = 0,
  89. _video_max_dimension: int = 0,
  90. _audio_samples: int = 0,
  91. _audio_channels: int = 0,
  92. output_format: str = "THWC",
  93. ) -> None:
  94. self.video_paths = video_paths
  95. self.num_workers = num_workers
  96. # these options are not valid for pyav backend
  97. self._video_width = _video_width
  98. self._video_height = _video_height
  99. self._video_min_dimension = _video_min_dimension
  100. self._video_max_dimension = _video_max_dimension
  101. self._audio_samples = _audio_samples
  102. self._audio_channels = _audio_channels
  103. self.output_format = output_format.upper()
  104. if self.output_format not in ("THWC", "TCHW"):
  105. raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
  106. if _precomputed_metadata is None:
  107. self._compute_frame_pts()
  108. else:
  109. self._init_from_metadata(_precomputed_metadata)
  110. self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
  111. def _compute_frame_pts(self) -> None:
  112. self.video_pts = [] # len = num_videos. Each entry is a tensor of shape (num_frames_in_video,)
  113. self.video_fps: List[float] = [] # len = num_videos
  114. # strategy: use a DataLoader to parallelize read_video_timestamps
  115. # so need to create a dummy dataset first
  116. import torch.utils.data
  117. dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
  118. _VideoTimestampsDataset(self.video_paths), # type: ignore[arg-type]
  119. batch_size=16,
  120. num_workers=self.num_workers,
  121. collate_fn=_collate_fn,
  122. )
  123. with tqdm(total=len(dl)) as pbar:
  124. for batch in dl:
  125. pbar.update(1)
  126. batch_pts, batch_fps = list(zip(*batch))
  127. # we need to specify dtype=torch.long because for empty list,
  128. # torch.as_tensor will use torch.float as default dtype. This
  129. # happens when decoding fails and no pts is returned in the list.
  130. batch_pts = [torch.as_tensor(pts, dtype=torch.long) for pts in batch_pts]
  131. self.video_pts.extend(batch_pts)
  132. self.video_fps.extend(batch_fps)
  133. def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
  134. self.video_paths = metadata["video_paths"]
  135. assert len(self.video_paths) == len(metadata["video_pts"])
  136. self.video_pts = metadata["video_pts"]
  137. assert len(self.video_paths) == len(metadata["video_fps"])
  138. self.video_fps = metadata["video_fps"]
  139. @property
  140. def metadata(self) -> Dict[str, Any]:
  141. _metadata = {
  142. "video_paths": self.video_paths,
  143. "video_pts": self.video_pts,
  144. "video_fps": self.video_fps,
  145. }
  146. return _metadata
  147. def subset(self, indices: List[int]) -> "VideoClips":
  148. video_paths = [self.video_paths[i] for i in indices]
  149. video_pts = [self.video_pts[i] for i in indices]
  150. video_fps = [self.video_fps[i] for i in indices]
  151. metadata = {
  152. "video_paths": video_paths,
  153. "video_pts": video_pts,
  154. "video_fps": video_fps,
  155. }
  156. return type(self)(
  157. video_paths,
  158. clip_length_in_frames=self.num_frames,
  159. frames_between_clips=self.step,
  160. frame_rate=self.frame_rate,
  161. _precomputed_metadata=metadata,
  162. num_workers=self.num_workers,
  163. _video_width=self._video_width,
  164. _video_height=self._video_height,
  165. _video_min_dimension=self._video_min_dimension,
  166. _video_max_dimension=self._video_max_dimension,
  167. _audio_samples=self._audio_samples,
  168. _audio_channels=self._audio_channels,
  169. output_format=self.output_format,
  170. )
  171. @staticmethod
  172. def compute_clips_for_video(
  173. video_pts: torch.Tensor, num_frames: int, step: int, fps: Optional[float], frame_rate: Optional[float] = None
  174. ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
  175. if fps is None:
  176. # if for some reason the video doesn't have fps (because doesn't have a video stream)
  177. # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
  178. fps = 1
  179. if frame_rate is None:
  180. frame_rate = fps
  181. total_frames = len(video_pts) * frame_rate / fps
  182. _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
  183. video_pts = video_pts[_idxs]
  184. clips = unfold(video_pts, num_frames, step)
  185. if not clips.numel():
  186. warnings.warn(
  187. "There aren't enough frames in the current video to get a clip for the given clip length and "
  188. "frames between clips. The video (and potentially others) will be skipped."
  189. )
  190. idxs: Union[List[slice], torch.Tensor]
  191. if isinstance(_idxs, slice):
  192. idxs = [_idxs] * len(clips)
  193. else:
  194. idxs = unfold(_idxs, num_frames, step)
  195. return clips, idxs
  196. def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[float] = None) -> None:
  197. """
  198. Compute all consecutive sequences of clips from video_pts.
  199. Always returns clips of size `num_frames`, meaning that the
  200. last few frames in a video can potentially be dropped.
  201. Args:
  202. num_frames (int): number of frames for the clip
  203. step (int): distance between two clips
  204. frame_rate (int, optional): The frame rate
  205. """
  206. self.num_frames = num_frames
  207. self.step = step
  208. self.frame_rate = frame_rate
  209. self.clips = []
  210. self.resampling_idxs = []
  211. for video_pts, fps in zip(self.video_pts, self.video_fps):
  212. clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
  213. self.clips.append(clips)
  214. self.resampling_idxs.append(idxs)
  215. clip_lengths = torch.as_tensor([len(v) for v in self.clips])
  216. self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
  217. def __len__(self) -> int:
  218. return self.num_clips()
  219. def num_videos(self) -> int:
  220. return len(self.video_paths)
  221. def num_clips(self) -> int:
  222. """
  223. Number of subclips that are available in the video list.
  224. """
  225. return self.cumulative_sizes[-1]
  226. def get_clip_location(self, idx: int) -> Tuple[int, int]:
  227. """
  228. Converts a flattened representation of the indices into a video_idx, clip_idx
  229. representation.
  230. """
  231. video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  232. if video_idx == 0:
  233. clip_idx = idx
  234. else:
  235. clip_idx = idx - self.cumulative_sizes[video_idx - 1]
  236. return video_idx, clip_idx
  237. @staticmethod
  238. def _resample_video_idx(num_frames: int, original_fps: float, new_fps: float) -> Union[slice, torch.Tensor]:
  239. step = original_fps / new_fps
  240. if step.is_integer():
  241. # optimization: if step is integer, don't need to perform
  242. # advanced indexing
  243. step = int(step)
  244. return slice(None, None, step)
  245. idxs = torch.arange(num_frames, dtype=torch.float32) * step
  246. idxs = idxs.floor().to(torch.int64)
  247. return idxs
  248. def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
  249. """
  250. Gets a subclip from a list of videos.
  251. Args:
  252. idx (int): index of the subclip. Must be between 0 and num_clips().
  253. Returns:
  254. video (Tensor)
  255. audio (Tensor)
  256. info (Dict)
  257. video_idx (int): index of the video in `video_paths`
  258. """
  259. if idx >= self.num_clips():
  260. raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)")
  261. video_idx, clip_idx = self.get_clip_location(idx)
  262. video_path = self.video_paths[video_idx]
  263. clip_pts = self.clips[video_idx][clip_idx]
  264. from torchvision import get_video_backend
  265. backend = get_video_backend()
  266. if backend == "pyav":
  267. # check for invalid options
  268. if self._video_width != 0:
  269. raise ValueError("pyav backend doesn't support _video_width != 0")
  270. if self._video_height != 0:
  271. raise ValueError("pyav backend doesn't support _video_height != 0")
  272. if self._video_min_dimension != 0:
  273. raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
  274. if self._video_max_dimension != 0:
  275. raise ValueError("pyav backend doesn't support _video_max_dimension != 0")
  276. if self._audio_samples != 0:
  277. raise ValueError("pyav backend doesn't support _audio_samples != 0")
  278. if backend == "pyav":
  279. start_pts = clip_pts[0].item()
  280. end_pts = clip_pts[-1].item()
  281. video, audio, info = read_video(video_path, start_pts, end_pts)
  282. else:
  283. _info = _probe_video_from_file(video_path)
  284. video_fps = _info.video_fps
  285. audio_fps = None
  286. video_start_pts = cast(int, clip_pts[0].item())
  287. video_end_pts = cast(int, clip_pts[-1].item())
  288. audio_start_pts, audio_end_pts = 0, -1
  289. audio_timebase = Fraction(0, 1)
  290. video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
  291. if _info.has_audio:
  292. audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
  293. audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
  294. audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
  295. audio_fps = _info.audio_sample_rate
  296. video, audio, _ = _read_video_from_file(
  297. video_path,
  298. video_width=self._video_width,
  299. video_height=self._video_height,
  300. video_min_dimension=self._video_min_dimension,
  301. video_max_dimension=self._video_max_dimension,
  302. video_pts_range=(video_start_pts, video_end_pts),
  303. video_timebase=video_timebase,
  304. audio_samples=self._audio_samples,
  305. audio_channels=self._audio_channels,
  306. audio_pts_range=(audio_start_pts, audio_end_pts),
  307. audio_timebase=audio_timebase,
  308. )
  309. info = {"video_fps": video_fps}
  310. if audio_fps is not None:
  311. info["audio_fps"] = audio_fps
  312. if self.frame_rate is not None:
  313. resampling_idx = self.resampling_idxs[video_idx][clip_idx]
  314. if isinstance(resampling_idx, torch.Tensor):
  315. resampling_idx = resampling_idx - resampling_idx[0]
  316. video = video[resampling_idx]
  317. info["video_fps"] = self.frame_rate
  318. assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
  319. if self.output_format == "TCHW":
  320. # [T,H,W,C] --> [T,C,H,W]
  321. video = video.permute(0, 3, 1, 2)
  322. return video, audio, info, video_idx
  323. def __getstate__(self) -> Dict[str, Any]:
  324. video_pts_sizes = [len(v) for v in self.video_pts]
  325. # To be back-compatible, we convert data to dtype torch.long as needed
  326. # because for empty list, in legacy implementation, torch.as_tensor will
  327. # use torch.float as default dtype. This happens when decoding fails and
  328. # no pts is returned in the list.
  329. video_pts = [x.to(torch.int64) for x in self.video_pts]
  330. # video_pts can be an empty list if no frames have been decoded
  331. if video_pts:
  332. video_pts = torch.cat(video_pts) # type: ignore[assignment]
  333. # avoid bug in https://github.com/pytorch/pytorch/issues/32351
  334. # TODO: Revert it once the bug is fixed.
  335. video_pts = video_pts.numpy() # type: ignore[attr-defined]
  336. # make a copy of the fields of self
  337. d = self.__dict__.copy()
  338. d["video_pts_sizes"] = video_pts_sizes
  339. d["video_pts"] = video_pts
  340. # delete the following attributes to reduce the size of dictionary. They
  341. # will be re-computed in "__setstate__()"
  342. del d["clips"]
  343. del d["resampling_idxs"]
  344. del d["cumulative_sizes"]
  345. # for backwards-compatibility
  346. d["_version"] = 2
  347. return d
  348. def __setstate__(self, d: Dict[str, Any]) -> None:
  349. # for backwards-compatibility
  350. if "_version" not in d:
  351. self.__dict__ = d
  352. return
  353. video_pts = torch.as_tensor(d["video_pts"], dtype=torch.int64)
  354. video_pts = torch.split(video_pts, d["video_pts_sizes"], dim=0)
  355. # don't need this info anymore
  356. del d["video_pts_sizes"]
  357. d["video_pts"] = video_pts
  358. self.__dict__ = d
  359. # recompute attributes "clips", "resampling_idxs" and other derivative ones
  360. self.compute_clips(self.num_frames, self.step, self.frame_rate)
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...