|
@@ -1,12 +1,17 @@
|
|
from typing import List, Optional, Tuple
|
|
from typing import List, Optional, Tuple
|
|
import cv2
|
|
import cv2
|
|
|
|
+import PIL
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
+from super_gradients.common.abstractions.abstract_logger import get_logger
|
|
|
|
+
|
|
|
|
+logger = get_logger(__name__)
|
|
|
|
+
|
|
__all__ = ["load_video", "save_video", "includes_video_extension", "show_video_from_disk", "show_video_from_frames"]
|
|
__all__ = ["load_video", "save_video", "includes_video_extension", "show_video_from_disk", "show_video_from_frames"]
|
|
|
|
|
|
-VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".wmv", ".flv")
|
|
|
|
|
|
+VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".wmv", ".flv", ".gif")
|
|
|
|
|
|
|
|
|
|
def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
|
|
def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
|
|
@@ -57,7 +62,37 @@ def _extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) ->
|
|
|
|
|
|
|
|
|
|
def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
|
|
def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
|
|
- """Save a video locally.
|
|
|
|
|
|
+ """Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.
|
|
|
|
+
|
|
|
|
+ :param output_path: Where the video will be saved
|
|
|
|
+ :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
|
|
|
|
+ :param fps: Frames per second
|
|
|
|
+ """
|
|
|
|
+ if not includes_video_extension(output_path):
|
|
|
|
+ logger.info(f'Output path "{output_path}" does not have a video extension, and therefore will be saved as {output_path}.mp4')
|
|
|
|
+ output_path += ".mp4"
|
|
|
|
+
|
|
|
|
+ if check_is_gif(output_path):
|
|
|
|
+ save_gif(output_path, frames, fps)
|
|
|
|
+ else:
|
|
|
|
+ save_mp4(output_path, frames, fps)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def save_gif(output_path: str, frames: List[np.ndarray], fps: int) -> None:
|
|
|
|
+ """Save a video locally in .gif format.
|
|
|
|
+
|
|
|
|
+ :param output_path: Where the video will be saved
|
|
|
|
+ :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
|
|
|
|
+ :param fps: Frames per second
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ frames_pil = [PIL.Image.fromarray(frame) for frame in frames]
|
|
|
|
+
|
|
|
|
+ frames_pil[0].save(output_path, save_all=True, append_images=frames_pil[1:], duration=int(1000 / fps), loop=0)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def save_mp4(output_path: str, frames: List[np.ndarray], fps: int) -> None:
|
|
|
|
+ """Save a video locally in .mp4 format.
|
|
|
|
|
|
:param output_path: Where the video will be saved
|
|
:param output_path: Where the video will be saved
|
|
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
|
|
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
|
|
@@ -151,3 +186,7 @@ def includes_video_extension(file_path: str) -> bool:
|
|
:return: True if the file includes a video extension.
|
|
:return: True if the file includes a video extension.
|
|
"""
|
|
"""
|
|
return isinstance(file_path, str) and file_path.lower().endswith(VIDEO_EXTENSIONS)
|
|
return isinstance(file_path, str) and file_path.lower().endswith(VIDEO_EXTENSIONS)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def check_is_gif(file_path: str) -> bool:
|
|
|
|
+ return file_path.lower().endswith(".gif")
|