Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#853 Add gif support on predict

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-818-add_gif_support
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
  1. import torch
  2. from super_gradients.common.object_names import Models
  3. from super_gradients.training import models
  4. # Note that currently only YoloX and PPYoloE are supported.
  5. model = models.get(Models.YOLOX_N, pretrained_weights="coco")
  6. # We want to use cuda if available to speed up inference.
  7. model = model.to("cuda" if torch.cuda.is_available() else "cpu")
  8. predictions = model.predict(
  9. "../../../../documentation/source/images/examples/pose_elephant_flip.gif",
  10. )
  11. predictions.show()
  12. predictions.save("pose_elephant_flip_prediction.gif")
  13. predictions.save("pose_elephant_flip_prediction.mp4") # Can also be saved as a mp4 video.
Discard
@@ -19,3 +19,4 @@ with open(video_path, mode="wb") as f:
 predictions = model.predict(video_path)
 predictions = model.predict(video_path)
 predictions.show()
 predictions.show()
 predictions.save("pose_elephant_flip_prediction.mp4")
 predictions.save("pose_elephant_flip_prediction.mp4")
+predictions.save("pose_elephant_flip_prediction.gif")  # Can also be saved as a gif.
Discard
@@ -1,12 +1,17 @@
 from typing import List, Optional, Tuple
 from typing import List, Optional, Tuple
 import cv2
 import cv2
+import PIL
 
 
 import numpy as np
 import numpy as np
 
 
 
 
+from super_gradients.common.abstractions.abstract_logger import get_logger
+
+logger = get_logger(__name__)
+
 __all__ = ["load_video", "save_video", "includes_video_extension", "show_video_from_disk", "show_video_from_frames"]
 __all__ = ["load_video", "save_video", "includes_video_extension", "show_video_from_disk", "show_video_from_frames"]
 
 
-VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".wmv", ".flv")
+VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".wmv", ".flv", ".gif")
 
 
 
 
 def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
 def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
@@ -57,7 +62,37 @@ def _extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) ->
 
 
 
 
 def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
 def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
-    """Save a video locally.
+    """Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.
+
+    :param output_path: Where the video will be saved
+    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
+    :param fps:         Frames per second
+    """
+    if not includes_video_extension(output_path):
+        logger.info(f'Output path "{output_path}" does not have a video extension, and therefore will be saved as {output_path}.mp4')
+        output_path += ".mp4"
+
+    if check_is_gif(output_path):
+        save_gif(output_path, frames, fps)
+    else:
+        save_mp4(output_path, frames, fps)
+
+
+def save_gif(output_path: str, frames: List[np.ndarray], fps: int) -> None:
+    """Save a video locally in .gif format.
+
+    :param output_path: Where the video will be saved
+    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
+    :param fps:         Frames per second
+    """
+
+    frames_pil = [PIL.Image.fromarray(frame) for frame in frames]
+
+    frames_pil[0].save(output_path, save_all=True, append_images=frames_pil[1:], duration=int(1000 / fps), loop=0)
+
+
+def save_mp4(output_path: str, frames: List[np.ndarray], fps: int) -> None:
+    """Save a video locally in .mp4 format.
 
 
     :param output_path: Where the video will be saved
     :param output_path: Where the video will be saved
     :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
     :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
@@ -151,3 +186,7 @@ def includes_video_extension(file_path: str) -> bool:
     :return:            True if the file includes a video extension.
     :return:            True if the file includes a video extension.
     """
     """
     return isinstance(file_path, str) and file_path.lower().endswith(VIDEO_EXTENSIONS)
     return isinstance(file_path, str) and file_path.lower().endswith(VIDEO_EXTENSIONS)
+
+
+def check_is_gif(file_path: str) -> bool:
+    return file_path.lower().endswith(".gif")
Discard