Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#829 Feature/sg 747 support predict video full pipeline master

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-747-support_predict_video_full_pipeline_master
22 changed files with 906 additions and 227 deletions
  1. BIN
      documentation/source/images/examples/countryside.jpg
  2. BIN
      documentation/source/images/examples/street_busy.jpg
  3. BIN
      documentation/source/images/examples/street_vehicles.jpg
  4. 5
    4
      src/super_gradients/examples/predict/detection_predict.py
  5. 9
    0
      src/super_gradients/examples/predict/detection_predict_image_folder.py
  6. 6
    0
      src/super_gradients/examples/predict/detection_predict_streaming.py
  7. 9
    0
      src/super_gradients/examples/predict/detection_predict_video.py
  8. 1
    1
      src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py
  9. 79
    3
      src/super_gradients/training/models/detection_models/customizable_detector.py
  10. 47
    5
      src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py
  11. 48
    6
      src/super_gradients/training/models/detection_models/yolo_base.py
  12. 3
    2
      src/super_gradients/training/models/model_factory.py
  13. 80
    51
      src/super_gradients/training/models/prediction_results.py
  14. 5
    2
      src/super_gradients/training/models/sg_module.py
  15. 155
    39
      src/super_gradients/training/pipelines/pipelines.py
  16. 108
    21
      src/super_gradients/training/transforms/processing.py
  17. 0
    87
      src/super_gradients/training/utils/load_image.py
  18. 0
    0
      src/super_gradients/training/utils/media/__init__.py
  19. 158
    0
      src/super_gradients/training/utils/media/image.py
  20. 117
    0
      src/super_gradients/training/utils/media/stream.py
  21. 59
    1
      src/super_gradients/training/utils/media/video.py
  22. 17
    5
      src/super_gradients/training/utils/utils.py
Discard
Discard
Discard
@@ -5,9 +5,10 @@ from super_gradients.training import models
 model = models.get(Models.PP_YOLOE_S, pretrained_weights="coco")
 model = models.get(Models.PP_YOLOE_S, pretrained_weights="coco")
 
 
 IMAGES = [
 IMAGES = [
-    "https://miro.medium.com/v2/resize:fit:500/0*w1s81z-Q72obhE_z",
-    "https://s.hs-data.com/bilder/spieler/gross/128069.jpg",
-    "https://datasets-server.huggingface.co/assets/Chris1/cityscapes/--/Chris1--cityscapes/train/28/image/image.jpg",
+    "../../../../documentation/source/images/examples/countryside.jpg",
+    "../../../../documentation/source/images/examples/street_busy.jpg",
+    "https://cdn-attachments.timesofmalta.com/cc1eceadde40d2940bc5dd20692901371622153217-1301777007-4d978a6f-620x348.jpg",
 ]
 ]
-prediction = model.predict(IMAGES, iou=0.65, conf=0.5)
+
+prediction = model.predict(IMAGES)
 prediction.show()
 prediction.show()
Discard
1
2
3
4
5
6
7
8
9
  1. from super_gradients.common.object_names import Models
  2. from super_gradients.training import models
  3. # Note that currently only YoloX and PPYoloE are supported.
  4. model = models.get(Models.YOLOX_N, pretrained_weights="coco")
  5. image_folder_path = "../../../../documentation/source/images/examples"
  6. predictions = model.predict(image_folder_path)
  7. predictions.show()
Discard
1
2
3
4
5
6
  1. from super_gradients.common.object_names import Models
  2. from super_gradients.training import models
  3. # Note that currently only YoloX and PPYoloE are supported.
  4. model = models.get(Models.YOLOX_N, pretrained_weights="coco")
  5. model.predict_webcam()
Discard
1
2
3
4
5
6
7
8
9
  1. from super_gradients.common.object_names import Models
  2. from super_gradients.training import models
  3. # Note that currently only YoloX and PPYoloE are supported.
  4. model = models.get(Models.YOLOX_N, pretrained_weights="coco")
  5. video_path = "<path/to/your/video>"
  6. predictions = model.predict(video_path)
  7. predictions.show()
Discard
@@ -5,7 +5,7 @@ import numpy as np
 from typing import List, Optional
 from typing import List, Optional
 
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.training.utils.load_image import is_image
+from super_gradients.training.utils.media.image import is_image
 from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
 from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
 from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormatConverter
 from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormatConverter
 from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_NORMALIZED_CXCYWH
 from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_NORMALIZED_CXCYWH
Discard
@@ -5,9 +5,7 @@ A base for a detection network built according to the following scheme:
  * each module accepts in_channels and other parameters
  * each module accepts in_channels and other parameters
  * each module defines out_channels property on construction
  * each module defines out_channels property on construction
 """
 """
-
-
-from typing import Union, Optional
+from typing import Union, Optional, List
 
 
 from torch import nn
 from torch import nn
 from omegaconf import DictConfig
 from omegaconf import DictConfig
@@ -15,6 +13,11 @@ from omegaconf import DictConfig
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.models.sg_module import SgModule
 from super_gradients.training.models.sg_module import SgModule
 import super_gradients.common.factories.detection_modules_factory as det_factory
 import super_gradients.common.factories.detection_modules_factory as det_factory
+from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
+from super_gradients.training.pipelines.pipelines import DetectionPipeline
+from super_gradients.training.transforms.processing import Processing
+from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
+from super_gradients.training.utils.media.image import ImageSource
 
 
 
 
 class CustomizableDetector(SgModule):
 class CustomizableDetector(SgModule):
@@ -67,6 +70,12 @@ class CustomizableDetector(SgModule):
 
 
         self._initialize_weights(bn_eps, bn_momentum, inplace_act)
         self._initialize_weights(bn_eps, bn_momentum, inplace_act)
 
 
+        # Processing params
+        self._class_names: Optional[List[str]] = None
+        self._image_processor: Optional[Processing] = None
+        self._default_nms_iou: Optional[float] = None
+        self._default_nms_conf: Optional[float] = None
+
     def forward(self, x):
     def forward(self, x):
         x = self.backbone(x)
         x = self.backbone(x)
         x = self.neck(x)
         x = self.neck(x)
@@ -96,3 +105,70 @@ class CustomizableDetector(SgModule):
             self.heads_params = factory.insert_module_param(self.heads_params, "num_classes", new_num_classes)
             self.heads_params = factory.insert_module_param(self.heads_params, "num_classes", new_num_classes)
             self.heads = factory.get(factory.insert_module_param(self.heads_params, "in_channels", self.neck.out_channels))
             self.heads = factory.get(factory.insert_module_param(self.heads_params, "in_channels", self.neck.out_channels))
             self._initialize_weights(self.bn_eps, self.bn_momentum, self.inplace_act)
             self._initialize_weights(self.bn_eps, self.bn_momentum, self.inplace_act)
+
+    @staticmethod
+    def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
+        raise NotImplementedError
+
+    def set_dataset_processing_params(
+        self,
+        class_names: Optional[List[str]] = None,
+        image_processor: Optional[Processing] = None,
+        iou: Optional[float] = None,
+        conf: Optional[float] = None,
+    ) -> None:
+        """Set the processing parameters for the dataset.
+
+        :param class_names:     (Optional) Names of the dataset the model was trained on.
+        :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training.
+        :param iou:             (Optional) IoU threshold for the nms algorithm
+        :param conf:            (Optional) Below the confidence threshold, prediction are discarded
+        """
+        self._class_names = class_names or self._class_names
+        self._image_processor = image_processor or self._image_processor
+        self._default_nms_iou = iou or self._default_nms_iou
+        self._default_nms_conf = conf or self._default_nms_conf
+
+    def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None) -> DetectionPipeline:
+        """Instantiate the prediction pipeline of this model.
+
+        :param iou:     (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
+        :param conf:    (Optional) Below the confidence threshold, prediction are discarded.
+                        If None, the default value associated to the training is used.
+        """
+        if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf):
+            raise RuntimeError(
+                "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first."
+            )
+
+        iou = iou or self._default_nms_iou
+        conf = conf or self._default_nms_conf
+
+        pipeline = DetectionPipeline(
+            model=self,
+            image_processor=self._image_processor,
+            post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
+            class_names=self._class_names,
+        )
+        return pipeline
+
+    def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None) -> ImagesDetectionPrediction:
+        """Predict an image or a list of images.
+
+        :param images:  Images to predict.
+        :param iou:     (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
+        :param conf:    (Optional) Below the confidence threshold, prediction are discarded.
+                        If None, the default value associated to the training is used.
+        """
+        pipeline = self._get_pipeline(iou=iou, conf=conf)
+        return pipeline(images)  # type: ignore
+
+    def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None):
+        """Predict using webcam.
+
+        :param iou:     (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
+        :param conf:    (Optional) Below the confidence threshold, prediction are discarded.
+                        If None, the default value associated to the training is used.
+        """
+        pipeline = self._get_pipeline(iou=iou, conf=conf)
+        pipeline.predict_webcam()
Discard
@@ -12,9 +12,10 @@ from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_head imp
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.utils import HpmStruct
 from super_gradients.training.models.arch_params_factory import get_arch_params
 from super_gradients.training.models.arch_params_factory import get_arch_params
 from super_gradients.training.models.detection_models.pp_yolo_e.post_prediction_callback import PPYoloEPostPredictionCallback, DetectionPostPredictionCallback
 from super_gradients.training.models.detection_models.pp_yolo_e.post_prediction_callback import PPYoloEPostPredictionCallback, DetectionPostPredictionCallback
-from super_gradients.training.models.results import DetectionResults
+from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
 from super_gradients.training.pipelines.pipelines import DetectionPipeline
 from super_gradients.training.pipelines.pipelines import DetectionPipeline
 from super_gradients.training.transforms.processing import Processing
 from super_gradients.training.transforms.processing import Processing
+from super_gradients.training.utils.media.image import ImageSource
 
 
 
 
 class PPYoloE(SgModule):
 class PPYoloE(SgModule):
@@ -29,34 +30,75 @@ class PPYoloE(SgModule):
 
 
         self._class_names: Optional[List[str]] = None
         self._class_names: Optional[List[str]] = None
         self._image_processor: Optional[Processing] = None
         self._image_processor: Optional[Processing] = None
+        self._default_nms_iou: Optional[float] = None
+        self._default_nms_conf: Optional[float] = None
 
 
     @staticmethod
     @staticmethod
     def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
     def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
         return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)
         return PPYoloEPostPredictionCallback(score_threshold=conf, nms_threshold=iou, nms_top_k=1000, max_predictions=300)
 
 
-    def set_dataset_processing_params(self, class_names: Optional[List[str]], image_processor: Optional[Processing]) -> None:
+    def set_dataset_processing_params(
+        self,
+        class_names: Optional[List[str]] = None,
+        image_processor: Optional[Processing] = None,
+        iou: Optional[float] = None,
+        conf: Optional[float] = None,
+    ) -> None:
         """Set the processing parameters for the dataset.
         """Set the processing parameters for the dataset.
 
 
         :param class_names:     (Optional) Names of the dataset the model was trained on.
         :param class_names:     (Optional) Names of the dataset the model was trained on.
         :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training.
         :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training.
+        :param iou:             (Optional) IoU threshold for the nms algorithm
+        :param conf:            (Optional) Below the confidence threshold, prediction are discarded
         """
         """
         self._class_names = class_names or self._class_names
         self._class_names = class_names or self._class_names
         self._image_processor = image_processor or self._image_processor
         self._image_processor = image_processor or self._image_processor
+        self._default_nms_iou = iou or self._default_nms_iou
+        self._default_nms_conf = conf or self._default_nms_conf
 
 
-    def predict(self, images, iou: float = 0.65, conf: float = 0.01) -> DetectionResults:
+    def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None) -> DetectionPipeline:
+        """Instantiate the prediction pipeline of this model.
 
 
-        if self._class_names is None or self._image_processor is None:
+        :param iou:     (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
+        :param conf:    (Optional) Below the confidence threshold, prediction are discarded.
+                        If None, the default value associated to the training is used.
+        """
+        if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf):
             raise RuntimeError(
             raise RuntimeError(
                 "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first."
                 "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first."
             )
             )
 
 
+        iou = iou or self._default_nms_iou
+        conf = conf or self._default_nms_conf
+
         pipeline = DetectionPipeline(
         pipeline = DetectionPipeline(
             model=self,
             model=self,
             image_processor=self._image_processor,
             image_processor=self._image_processor,
             post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
             post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
             class_names=self._class_names,
             class_names=self._class_names,
         )
         )
-        return pipeline(images)
+        return pipeline
+
+    def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None) -> ImagesDetectionPrediction:
+        """Predict an image or a list of images.
+
+        :param images:  Images to predict.
+        :param iou:     (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
+        :param conf:    (Optional) Below the confidence threshold, prediction are discarded.
+                        If None, the default value associated to the training is used.
+        """
+        pipeline = self._get_pipeline(iou=iou, conf=conf)
+        return pipeline(images)  # type: ignore
+
+    def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None):
+        """Predict using webcam.
+
+        :param iou:     (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
+        :param conf:    (Optional) Below the confidence threshold, prediction are discarded.
+                        If None, the default value associated to the training is used.
+        """
+        pipeline = self._get_pipeline(iou=iou, conf=conf)
+        pipeline.predict_webcam()
 
 
     def forward(self, x: Tensor):
     def forward(self, x: Tensor):
         features = self.backbone(x)
         features = self.backbone(x)
Discard
@@ -11,10 +11,10 @@ from super_gradients.training.models.sg_module import SgModule
 from super_gradients.training.utils import torch_version_is_greater_or_equal
 from super_gradients.training.utils import torch_version_is_greater_or_equal
 from super_gradients.training.utils.detection_utils import non_max_suppression, matrix_non_max_suppression, NMS_Type, DetectionPostPredictionCallback, Anchors
 from super_gradients.training.utils.detection_utils import non_max_suppression, matrix_non_max_suppression, NMS_Type, DetectionPostPredictionCallback, Anchors
 from super_gradients.training.utils.utils import HpmStruct, check_img_size_divisibility, get_param
 from super_gradients.training.utils.utils import HpmStruct, check_img_size_divisibility, get_param
-from super_gradients.training.models.results import DetectionResults
+from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
 from super_gradients.training.pipelines.pipelines import DetectionPipeline
 from super_gradients.training.pipelines.pipelines import DetectionPipeline
 from super_gradients.training.transforms.processing import Processing
 from super_gradients.training.transforms.processing import Processing
-
+from super_gradients.training.utils.media.image import ImageSource
 
 
 COCO_DETECTION_80_CLASSES_BBOX_ANCHORS = Anchors(
 COCO_DETECTION_80_CLASSES_BBOX_ANCHORS = Anchors(
     [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], strides=[8, 16, 32]
     [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], strides=[8, 16, 32]
@@ -418,33 +418,75 @@ class YoloBase(SgModule):
 
 
         self._class_names: Optional[List[str]] = None
         self._class_names: Optional[List[str]] = None
         self._image_processor: Optional[Processing] = None
         self._image_processor: Optional[Processing] = None
+        self._default_nms_iou: Optional[float] = None
+        self._default_nms_conf: Optional[float] = None
 
 
     @staticmethod
     @staticmethod
     def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
     def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
         return YoloPostPredictionCallback(conf=conf, iou=iou)
         return YoloPostPredictionCallback(conf=conf, iou=iou)
 
 
-    def set_dataset_processing_params(self, class_names: Optional[List[str]], image_processor: Optional[Processing]) -> None:
+    def set_dataset_processing_params(
+        self,
+        class_names: Optional[List[str]] = None,
+        image_processor: Optional[Processing] = None,
+        iou: Optional[float] = None,
+        conf: Optional[float] = None,
+    ) -> None:
         """Set the processing parameters for the dataset.
         """Set the processing parameters for the dataset.
 
 
         :param class_names:     (Optional) Names of the dataset the model was trained on.
         :param class_names:     (Optional) Names of the dataset the model was trained on.
         :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training.
         :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training.
+        :param iou:             (Optional) IoU threshold for the nms algorithm
+        :param conf:            (Optional) Below the confidence threshold, prediction are discarded
         """
         """
         self._class_names = class_names or self._class_names
         self._class_names = class_names or self._class_names
         self._image_processor = image_processor or self._image_processor
         self._image_processor = image_processor or self._image_processor
+        self._default_nms_iou = iou or self._default_nms_iou
+        self._default_nms_conf = conf or self._default_nms_conf
+
+    def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None) -> DetectionPipeline:
+        """Instantiate the prediction pipeline of this model.
 
 
-    def predict(self, images, iou: float = 0.65, conf: float = 0.01) -> DetectionResults:
-        if self._class_names is None or self._image_processor is None:
+        :param iou:     (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
+        :param conf:    (Optional) Below the confidence threshold, prediction are discarded.
+                        If None, the default value associated to the training is used.
+        """
+        if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf):
             raise RuntimeError(
             raise RuntimeError(
                 "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first."
                 "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first."
             )
             )
 
 
+        iou = iou or self._default_nms_iou
+        conf = conf or self._default_nms_conf
+
         pipeline = DetectionPipeline(
         pipeline = DetectionPipeline(
             model=self,
             model=self,
             image_processor=self._image_processor,
             image_processor=self._image_processor,
             post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
             post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
             class_names=self._class_names,
             class_names=self._class_names,
         )
         )
-        return pipeline(images)
+        return pipeline
+
+    def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None) -> ImagesDetectionPrediction:
+        """Predict an image or a list of images.
+
+        :param images:  Images to predict.
+        :param iou:     (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
+        :param conf:    (Optional) Below the confidence threshold, prediction are discarded.
+                        If None, the default value associated to the training is used.
+        """
+        pipeline = self._get_pipeline(iou=iou, conf=conf)
+        return pipeline(images)  # type: ignore
+
+    def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None):
+        """Predict using webcam.
+
+        :param iou:     (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
+        :param conf:    (Optional) Below the confidence threshold, prediction are discarded.
+                        If None, the default value associated to the training is used.
+        """
+        pipeline = self._get_pipeline(iou=iou, conf=conf)
+        pipeline.predict_webcam()
 
 
     def forward(self, x):
     def forward(self, x):
         out = self._backbone(x)
         out = self._backbone(x)
Discard
@@ -136,8 +136,9 @@ def instantiate_model(
                 net.replace_head(new_num_classes=num_classes_new_head)
                 net.replace_head(new_num_classes=num_classes_new_head)
                 arch_params.num_classes = num_classes_new_head
                 arch_params.num_classes = num_classes_new_head
 
 
-            class_names, image_processor = get_pretrained_processing_params(model_name, pretrained_weights)
-            net.set_dataset_processing_params(class_names, image_processor)
+            # TODO: remove once we load it from the checkpoint
+            processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
+            net.set_dataset_processing_params(**processing_params)
 
 
     _add_model_name_attribute(net, model_name)
     _add_model_name_attribute(net, model_name)
 
 
Discard
@@ -1,17 +1,18 @@
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Iterator
 from dataclasses import dataclass
 from dataclasses import dataclass
-from matplotlib import pyplot as plt
 
 
 import numpy as np
 import numpy as np
 
 
 from super_gradients.training.utils.detection_utils import DetectionVisualization
 from super_gradients.training.utils.detection_utils import DetectionVisualization
 from super_gradients.training.models.predictions import Prediction, DetectionPrediction
 from super_gradients.training.models.predictions import Prediction, DetectionPrediction
+from super_gradients.training.utils.media.video import show_video_from_frames
+from super_gradients.training.utils.media.image import show_image
 
 
 
 
 @dataclass
 @dataclass
-class Result(ABC):
-    """Results of a given computer vision task (detection, classification, etc.).
+class ImagePrediction(ABC):
+    """Object wrapping an image and a model's prediction.
 
 
     :attr image:        Input image
     :attr image:        Input image
     :attr predictions:  Predictions of the model
     :attr predictions:  Predictions of the model
@@ -19,7 +20,7 @@ class Result(ABC):
     """
     """
 
 
     image: np.ndarray
     image: np.ndarray
-    predictions: Prediction
+    prediction: Prediction
     class_names: List[str]
     class_names: List[str]
 
 
     @abstractmethod
     @abstractmethod
@@ -34,28 +35,8 @@ class Result(ABC):
 
 
 
 
 @dataclass
 @dataclass
-class Results(ABC):
-    """List of results of a given computer vision task (detection, classification, etc.).
-
-    :attr results: List of results of the run
-    """
-
-    results: List[Result]
-
-    @abstractmethod
-    def draw(self) -> List[np.ndarray]:
-        """Draw the predictions on the image."""
-        pass
-
-    @abstractmethod
-    def show(self) -> None:
-        """Display the predictions on the image."""
-        pass
-
-
-@dataclass
-class DetectionResult(Result):
-    """Result of a detection task.
+class ImageDetectionPrediction(ImagePrediction):
+    """Object wrapping an image and a detection model's prediction.
 
 
     :attr image:        Input image
     :attr image:        Input image
     :attr predictions:  Predictions of the model
     :attr predictions:  Predictions of the model
@@ -63,7 +44,7 @@ class DetectionResult(Result):
     """
     """
 
 
     image: np.ndarray
     image: np.ndarray
-    predictions: DetectionPrediction
+    prediction: DetectionPrediction
     class_names: List[str]
     class_names: List[str]
 
 
     def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> np.ndarray:
     def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> np.ndarray:
@@ -78,18 +59,18 @@ class DetectionResult(Result):
         image_np = self.image.copy()
         image_np = self.image.copy()
         color_mapping = color_mapping or DetectionVisualization._generate_color_mapping(len(self.class_names))
         color_mapping = color_mapping or DetectionVisualization._generate_color_mapping(len(self.class_names))
 
 
-        for pred_i in range(len(self.predictions)):
+        for pred_i in range(len(self.prediction)):
             image_np = DetectionVisualization._draw_box_title(
             image_np = DetectionVisualization._draw_box_title(
                 color_mapping=color_mapping,
                 color_mapping=color_mapping,
                 class_names=self.class_names,
                 class_names=self.class_names,
                 box_thickness=box_thickness,
                 box_thickness=box_thickness,
                 image_np=image_np,
                 image_np=image_np,
-                x1=int(self.predictions.bboxes_xyxy[pred_i, 0]),
-                y1=int(self.predictions.bboxes_xyxy[pred_i, 1]),
-                x2=int(self.predictions.bboxes_xyxy[pred_i, 2]),
-                y2=int(self.predictions.bboxes_xyxy[pred_i, 3]),
-                class_id=int(self.predictions.labels[pred_i]),
-                pred_conf=self.predictions.confidence[pred_i] if show_confidence else None,
+                x1=int(self.prediction.bboxes_xyxy[pred_i, 0]),
+                y1=int(self.prediction.bboxes_xyxy[pred_i, 1]),
+                x2=int(self.prediction.bboxes_xyxy[pred_i, 2]),
+                y2=int(self.prediction.bboxes_xyxy[pred_i, 3]),
+                class_id=int(self.prediction.labels[pred_i]),
+                pred_conf=self.prediction.confidence[pred_i] if show_confidence else None,
             )
             )
         return image_np
         return image_np
 
 
@@ -102,34 +83,80 @@ class DetectionResult(Result):
                                 Default is None, which generates a default color mapping based on the number of class names.
                                 Default is None, which generates a default color mapping based on the number of class names.
         """
         """
         image_np = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
         image_np = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
+        show_image(image_np)
+
+
+@dataclass
+class ImagesPredictions(ABC):
+    """Object wrapping the list of image predictions.
 
 
-        plt.imshow(image_np, interpolation="nearest")
-        plt.axis("off")
-        plt.show()
+    :attr _images_prediction_lst: List of results of the run
+    """
+
+    _images_prediction_lst: List[ImagePrediction]
+
+    def __len__(self) -> int:
+        return len(self._images_prediction_lst)
+
+    def __getitem__(self, index: int) -> ImagePrediction:
+        return self._images_prediction_lst[index]
+
+    def __iter__(self) -> Iterator[ImagePrediction]:
+        return iter(self._images_prediction_lst)
+
+    @abstractmethod
+    def show(self) -> None:
+        pass
 
 
 
 
 @dataclass
 @dataclass
-class DetectionResults(Results):
-    """Results of a detection task.
+class VideoPredictions(ImagesPredictions, ABC):
+    """Object wrapping the list of image predictions as a Video.
 
 
-    :attr results:  List of the predictions results
+    :attr _images_prediction_lst:   List of results of the run
+    :att fps:                       Frames per second of the video
     """
     """
 
 
-    def __init__(self, images: List[np.ndarray], predictions: List[DetectionPrediction], class_names: List[str]):
-        self.results: List[DetectionResult] = []
-        for image, prediction in zip(images, predictions):
-            self.results.append(DetectionResult(image=image, predictions=prediction, class_names=class_names))
+    _images_prediction_lst: List[ImagePrediction]
+    fps: float
+
+    @abstractmethod
+    def show(self, *args, **kwargs) -> None:
+        """Display the predictions on the image."""
+        pass
+
+
+@dataclass
+class ImagesDetectionPrediction(ImagesPredictions):
+    """Object wrapping the list of image detection predictions.
+
+    :attr _images_prediction_lst:  List of the predictions results
+    """
 
 
-    def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> List[np.ndarray]:
-        """Draw the predicted bboxes on the images.
+    _images_prediction_lst: List[ImageDetectionPrediction]
+
+    def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> None:
+        """Display the predicted bboxes on the images.
 
 
         :param box_thickness:   Thickness of bounding boxes.
         :param box_thickness:   Thickness of bounding boxes.
         :param show_confidence: Whether to show confidence scores on the image.
         :param show_confidence: Whether to show confidence scores on the image.
         :param color_mapping:   List of tuples representing the colors for each class.
         :param color_mapping:   List of tuples representing the colors for each class.
                                 Default is None, which generates a default color mapping based on the number of class names.
                                 Default is None, which generates a default color mapping based on the number of class names.
-        :return:                List of Images with predicted bboxes for each image. Note that this does not modify the original images.
         """
         """
-        return [prediction.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) for prediction in self.results]
+        for prediction in self._images_prediction_lst:
+            prediction.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
+
+
+@dataclass
+class VideoDetectionPrediction(VideoPredictions):
+    """Object wrapping the list of image detection predictions as a Video.
+
+    :attr _images_prediction_lst:   List of the predictions results
+    :att fps:                       Frames per second of the video
+    """
+
+    _images_prediction_lst: List[ImageDetectionPrediction]
+    fps: float
 
 
     def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> None:
     def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int]]] = None) -> None:
         """Display the predicted bboxes on the images.
         """Display the predicted bboxes on the images.
@@ -139,5 +166,7 @@ class DetectionResults(Results):
         :param color_mapping:   List of tuples representing the colors for each class.
         :param color_mapping:   List of tuples representing the colors for each class.
                                 Default is None, which generates a default color mapping based on the number of class names.
                                 Default is None, which generates a default color mapping based on the number of class names.
         """
         """
-        for prediction in self.results:
-            prediction.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
+        frames = [
+            result.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) for result in self._images_prediction_lst
+        ]
+        show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps)
Discard
@@ -3,7 +3,7 @@ from typing import Union
 from torch import nn
 from torch import nn
 
 
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.utils.utils import HpmStruct
-from super_gradients.training.models.results import Result
+from super_gradients.training.models.prediction_results import ImagesPredictions
 
 
 
 
 class SgModule(nn.Module):
 class SgModule(nn.Module):
@@ -64,9 +64,12 @@ class SgModule(nn.Module):
 
 
         raise NotImplementedError
         raise NotImplementedError
 
 
-    def predict(self, images, *args, **kwargs) -> Result:
+    def predict(self, images, *args, **kwargs) -> ImagesPredictions:
         raise NotImplementedError(f"`predict` is not implemented for {self.__class__.__name__}.")
         raise NotImplementedError(f"`predict` is not implemented for {self.__class__.__name__}.")
 
 
+    def predict_webcam(self, *args, **kwargs) -> None:
+        raise NotImplementedError(f"`predict_webcam` is not implemented for {self.__class__.__name__}.")
+
     def set_dataset_processing_params(self, *args, **kwargs) -> None:
     def set_dataset_processing_params(self, *args, **kwargs) -> None:
         """Set the processing parameters for the dataset."""
         """Set the processing parameters for the dataset."""
         pass
         pass
Discard
@@ -1,16 +1,30 @@
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Iterable
 from contextlib import contextmanager
 from contextlib import contextmanager
+from tqdm import tqdm
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch
-
-from super_gradients.training.utils.load_image import load_images, ImageType
+from super_gradients.training.utils.utils import generate_batch
+from super_gradients.training.utils.media.video import load_video, is_video
+from super_gradients.training.utils.media.image import ImageSource, check_image_typing
+from super_gradients.training.utils.media.stream import WebcamStreaming
 from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
 from super_gradients.training.models.sg_module import SgModule
 from super_gradients.training.models.sg_module import SgModule
-from super_gradients.training.models.results import Results, DetectionResults
+from super_gradients.training.models.prediction_results import (
+    ImagesDetectionPrediction,
+    VideoDetectionPrediction,
+    ImagePrediction,
+    ImageDetectionPrediction,
+    ImagesPredictions,
+    VideoPredictions,
+)
 from super_gradients.training.models.predictions import Prediction, DetectionPrediction
 from super_gradients.training.models.predictions import Prediction, DetectionPrediction
 from super_gradients.training.transforms.processing import Processing, ComposeProcessing
 from super_gradients.training.transforms.processing import Processing, ComposeProcessing
+from super_gradients.common.abstractions.abstract_logger import get_logger
+
+
+logger = get_logger(__name__)
 
 
 
 
 @contextmanager
 @contextmanager
@@ -35,39 +49,101 @@ class Pipeline(ABC):
     :param device:          The device on which the model will be run. Defaults to "cpu". Use "cuda" for GPU support.
     :param device:          The device on which the model will be run. Defaults to "cpu". Use "cuda" for GPU support.
     """
     """
 
 
-    def __init__(self, model: SgModule, image_processor: Union[Processing, List[Processing]], device: Optional[str] = "cpu"):
+    def __init__(self, model: SgModule, image_processor: Union[Processing, List[Processing]], class_names: List[str], device: Optional[str] = "cpu"):
         super().__init__()
         super().__init__()
         self.model = model.to(device)
         self.model = model.to(device)
         self.device = device
         self.device = device
+        self.class_names = class_names
 
 
         if isinstance(image_processor, list):
         if isinstance(image_processor, list):
             image_processor = ComposeProcessing(image_processor)
             image_processor = ComposeProcessing(image_processor)
         self.image_processor = image_processor
         self.image_processor = image_processor
 
 
-    @abstractmethod
-    def __call__(self, images: Union[ImageType, List[ImageType]]) -> Union[Results, Tuple[List[np.ndarray], List[Prediction]]]:
-        """Apply the pipeline on images and return the result.
+    def __call__(self, inputs: Union[str, ImageSource, List[ImageSource]], batch_size: Optional[int] = 32) -> ImagesPredictions:
+        """Predict an image or a list of images.
+
+        Supported types include:
+            - str:              A string representing either a video, an image or an URL.
+            - numpy.ndarray:    A numpy array representing the image
+            - torch.Tensor:     A PyTorch tensor representing the image
+            - PIL.Image.Image:  A PIL Image object
+            - List:             A list of images of any of the above image types (list of videos not supported).
 
 
-        :param images:  Single image or a list of images of supported types.
-        :return         Results object containing the results of the prediction and the image.
+        :param inputs:      inputs to the model, which can be any of the above-mentioned types.
+        :param batch_size:  Number of images to be processed at the same time.
+        :return:            Results of the prediction.
         """
         """
-        return self._run(images=images)
-
-    def _run(self, images: Union[ImageType, List[ImageType]]) -> Tuple[List[np.ndarray], List[Prediction]]:
-        """Run the pipeline and return (image, predictions). The pipeline is made of 4 steps:
-        1. Load images - Loading the images into a list of numpy arrays.
-        2. Preprocess - Encode the image in the shape/format expected by the model
-        3. Predict - Run the model on the preprocessed image
-        4. Postprocess - Decode the output of the model so that the predictions are in the shape/format of original image.
-
-        :param images:  Single image or a list of images of supported types.
-        :return:
-            - List of numpy arrays representing images.
-            - List of model predictions.
+
+        if is_video(inputs):
+            return self.predict_video(inputs, batch_size)
+        elif check_image_typing(inputs):
+            return self.predict_images(inputs, batch_size)
+        else:
+            raise ValueError(f"Input {inputs} not supported for prediction.")
+
+    def predict_images(self, images: Union[ImageSource, List[ImageSource]], batch_size: Optional[int] = 32) -> ImagesPredictions:
+        """Predict an image or a list of images.
+
+        :param images:      Images to predict.
+        :param batch_size:  The size of each batch.
+        :return:            Results of the prediction.
         """
         """
-        self.model = self.model.to(self.device)  # Make sure the model is on the correct device, as it might have been moved after init
+        from super_gradients.training.utils.media.image import load_images
 
 
         images = load_images(images)
         images = load_images(images)
+        result_generator = self._generate_prediction_result(images=images, batch_size=batch_size)
+        return self._combine_image_prediction_to_images(result_generator, n_images=len(images))
+
+    def predict_video(self, video_path: str, batch_size: Optional[int] = 32) -> VideoPredictions:
+        """Predict on a video file, by processing the frames in batches.
+
+        :param video_path:  Path to the video file.
+        :param batch_size:  The size of each batch.
+        :return:            Results of the prediction.
+        """
+        video_frames, fps = load_video(file_path=video_path)
+        result_generator = self._generate_prediction_result(images=video_frames, batch_size=batch_size)
+        return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=len(video_frames))
+
+    def predict_webcam(self) -> None:
+        """Predict using webcam"""
+
+        def _draw_predictions(frame: np.ndarray) -> np.ndarray:
+            """Draw the predictions on a single frame from the stream."""
+            frame_prediction = next(iter(self._generate_prediction_result(images=[frame])))
+            return frame_prediction.draw()
+
+        video_streaming = WebcamStreaming(frame_processing_fn=_draw_predictions, fps_update_frequency=1)
+        video_streaming.run()
+
+    def _generate_prediction_result(self, images: Iterable[np.ndarray], batch_size: Optional[int] = None) -> Iterable[ImagePrediction]:
+        """Run the pipeline on the images as single batch or through multiple batches.
+
+        NOTE: A core motivation to have this function as a generator is that it can be used in a lazy way (if images is generator itself),
+              i.e. without having to load all the images into memory.
+
+        :param images:      Iterable of numpy arrays representing images.
+        :param batch_size:  The size of each batch.
+        :return:            Iterable of Results object, each containing the results of the prediction and the image.
+        """
+        if batch_size is None:
+            yield from self._generate_prediction_result_single_batch(images)
+        else:
+            for batch_images in generate_batch(images, batch_size):
+                yield from self._generate_prediction_result_single_batch(batch_images)
+
+    def _generate_prediction_result_single_batch(self, images: Iterable[np.ndarray]) -> Iterable[ImagePrediction]:
+        """Run the pipeline on images. The pipeline is made of 4 steps:
+            1. Load images - Loading the images into a list of numpy arrays.
+            2. Preprocess - Encode the image in the shape/format expected by the model
+            3. Predict - Run the model on the preprocessed image
+            4. Postprocess - Decode the output of the model so that the predictions are in the shape/format of original image.
+
+        :param images:  Iterable of numpy arrays representing images.
+        :return:        Iterable of Results object, each containing the results of the prediction and the image.
+        """
+        images = list(images)  # We need to load all the images into memory, and to reuse it afterwards.
+        self.model = self.model.to(self.device)  # Make sure the model is on the correct device, as it might have been moved after init
 
 
         # Preprocess
         # Preprocess
         preprocessed_images, processing_metadatas = [], []
         preprocessed_images, processing_metadatas = [], []
@@ -84,11 +160,13 @@ class Pipeline(ABC):
 
 
         # Postprocess
         # Postprocess
         postprocessed_predictions = []
         postprocessed_predictions = []
-        for prediction, processing_metadata in zip(predictions, processing_metadatas):
+        for image, prediction, processing_metadata in zip(images, predictions, processing_metadatas):
             prediction = self.image_processor.postprocess_predictions(predictions=prediction, metadata=processing_metadata)
             prediction = self.image_processor.postprocess_predictions(predictions=prediction, metadata=processing_metadata)
             postprocessed_predictions.append(prediction)
             postprocessed_predictions.append(prediction)
 
 
-        return images, postprocessed_predictions
+        # Yield results one by one
+        for image, prediction in zip(images, postprocessed_predictions):
+            yield self._instantiate_image_prediction(image=image, prediction=prediction)
 
 
     @abstractmethod
     @abstractmethod
     def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[Prediction]:
     def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[Prediction]:
@@ -98,7 +176,40 @@ class Pipeline(ABC):
         :param model_input:     Model input (i.e. images after preprocessing).
         :param model_input:     Model input (i.e. images after preprocessing).
         :return:                Model predictions, without any post-processing.
         :return:                Model predictions, without any post-processing.
         """
         """
-        pass
+        raise NotImplementedError
+
+    @abstractmethod
+    def _instantiate_image_prediction(self, image: np.ndarray, prediction: Prediction) -> ImagePrediction:
+        """Instantiate an object wrapping an image and the pipeline's prediction.
+
+        :param image:       Image to predict.
+        :param prediction:  Model prediction on that image.
+        :return:            Object wrapping an image and the pipeline's prediction.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _combine_image_prediction_to_images(self, images_prediction_lst: Iterable[ImagePrediction], n_images: Optional[int] = None) -> ImagesPredictions:
+        """Instantiate an object wrapping the list of images and the pipeline's predictions on them.
+
+        :param images_prediction_lst:   List of image predictions.
+        :param n_images:                (Optional) Number of images in the list. This used for tqdm progress bar to work with iterables, but is not required.
+        :return:                        Object wrapping the list of image predictions.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _combine_image_prediction_to_video(
+        self, images_prediction_lst: Iterable[ImagePrediction], fps: float, n_images: Optional[int] = None
+    ) -> VideoPredictions:
+        """Instantiate an object holding the video frames and the pipeline's predictions on it.
+
+        :param images_prediction_lst:   List of image predictions.
+        :param fps:                     Frames per second.
+        :param n_images:                (Optional) Number of images in the list. This used for tqdm progress bar to work with iterables, but is not required.
+        :return:                        Object wrapping the list of image predictions as a Video.
+        """
+        raise NotImplementedError
 
 
 
 
 class DetectionPipeline(Pipeline):
 class DetectionPipeline(Pipeline):
@@ -120,18 +231,8 @@ class DetectionPipeline(Pipeline):
         device: Optional[str] = "cpu",
         device: Optional[str] = "cpu",
         image_processor: Optional[Processing] = None,
         image_processor: Optional[Processing] = None,
     ):
     ):
-        super().__init__(model=model, device=device, image_processor=image_processor)
+        super().__init__(model=model, device=device, image_processor=image_processor, class_names=class_names)
         self.post_prediction_callback = post_prediction_callback
         self.post_prediction_callback = post_prediction_callback
-        self.class_names = class_names
-
-    def __call__(self, images: Union[List[ImageType], ImageType]) -> DetectionResults:
-        """Apply the pipeline on images and return the detection result.
-
-        :param images:  Single image or a list of images of supported types.
-        :return         Results object containing the results of the prediction and the image.
-        """
-        images, predictions = super().__call__(images=images)
-        return DetectionResults(images=images, predictions=predictions, class_names=self.class_names)
 
 
     def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[DetectionPrediction]:
     def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[DetectionPrediction]:
         """Decode the model output, by applying post prediction callback. This includes NMS.
         """Decode the model output, by applying post prediction callback. This includes NMS.
@@ -144,7 +245,7 @@ class DetectionPipeline(Pipeline):
 
 
         predictions = []
         predictions = []
         for prediction, image in zip(post_nms_predictions, model_input):
         for prediction, image in zip(post_nms_predictions, model_input):
-            prediction if prediction is not None else torch.zeros((0, 6), dtype=torch.float32)
+            prediction = prediction if prediction is not None else torch.zeros((0, 6), dtype=torch.float32)
             prediction = prediction.detach().cpu().numpy()
             prediction = prediction.detach().cpu().numpy()
             predictions.append(
             predictions.append(
                 DetectionPrediction(
                 DetectionPrediction(
@@ -157,3 +258,18 @@ class DetectionPipeline(Pipeline):
             )
             )
 
 
         return predictions
         return predictions
+
+    def _instantiate_image_prediction(self, image: np.ndarray, prediction: DetectionPrediction) -> ImagePrediction:
+        return ImageDetectionPrediction(image=image, prediction=prediction, class_names=self.class_names)
+
+    def _combine_image_prediction_to_images(
+        self, images_predictions: Iterable[ImageDetectionPrediction], n_images: Optional[int] = None
+    ) -> ImagesDetectionPrediction:
+        images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")]
+        return ImagesDetectionPrediction(_images_prediction_lst=images_predictions)
+
+    def _combine_image_prediction_to_video(
+        self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None
+    ) -> VideoDetectionPrediction:
+        images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")]
+        return VideoDetectionPrediction(_images_prediction_lst=images_predictions, fps=fps)
Discard
@@ -1,9 +1,10 @@
-from typing import Tuple, List, Union, Optional
+from typing import Tuple, List, Union
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from dataclasses import dataclass
 
 
 import numpy as np
 import numpy as np
 
 
+from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
 from super_gradients.training.models.predictions import Prediction, DetectionPrediction
 from super_gradients.training.models.predictions import Prediction, DetectionPrediction
 from super_gradients.training.transforms.utils import (
 from super_gradients.training.transforms.utils import (
     _rescale_image,
     _rescale_image,
@@ -96,6 +97,49 @@ class ImagePermute(Processing):
         return predictions
         return predictions
 
 
 
 
+class ReverseImageChannels(Processing):
+    """Reverse the order of the image channels (RGB -> BGR or BGR -> RGB)."""
+
+    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
+        """Reverse the channel order of an image.
+
+        :param image: Image, in (H, W, C) format.
+        :return:      Image with reversed channel order. (RGB if input was BGR, BGR if input was RGB)
+        """
+
+        if image.shape[2] != 3:
+            raise ValueError("ReverseImageChannels expects 3 channels, got: " + str(image.shape[2]))
+
+        processed_image = image[..., ::-1]
+        return processed_image, None
+
+    def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Prediction:
+        return predictions
+
+
+class StandardizeImage(Processing):
+    """Standardize image pixel values with img/max_val
+
+    :param max_value: Current maximum value of the image pixels. (usually 255)
+    """
+
+    def __init__(self, max_value: float = 255.0):
+        super().__init__()
+        self.max_value = max_value
+
+    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
+        """Reverse the channel order of an image.
+
+        :param image: Image, in (H, W, C) format.
+        :return:      Image with reversed channel order. (RGB if input was BGR, BGR if input was RGB)
+        """
+        processed_image = (image / self.max_value).astype(np.float32)
+        return processed_image, None
+
+    def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Prediction:
+        return predictions
+
+
 class NormalizeImage(Processing):
 class NormalizeImage(Processing):
     """Normalize an image based on means and standard deviation.
     """Normalize an image based on means and standard deviation.
 
 
@@ -204,41 +248,84 @@ class DetectionLongestMaxSizeRescale(_LongestMaxSizeRescale):
         return predictions
         return predictions
 
 
 
 
-def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -> Tuple[Optional[List[str]], Optional[Processing]]:
-    """Get the processing parameters for a pretrained model."""
-    if "yolox" in model_name and pretrained_weights == "coco":
-        return default_yolox_coco_processing_params()
-    elif "ppyoloe" in model_name and pretrained_weights == "coco":
-        return default_ppyoloe_coco_processing_params()
-    else:
-        return None, None
-
-
-def default_yolox_coco_processing_params() -> Tuple[List[str], Processing]:
-    """Processing parameters commonly used for training YoloX on COCO dataset."""
-    from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
+def default_yolox_coco_processing_params() -> dict:
+    """Processing parameters commonly used for training YoloX on COCO dataset.
+    TODO: remove once we load it from the checkpoint
+    """
 
 
     image_processor = ComposeProcessing(
     image_processor = ComposeProcessing(
         [
         [
+            ReverseImageChannels(),
             DetectionLongestMaxSizeRescale((640, 640)),
             DetectionLongestMaxSizeRescale((640, 640)),
             DetectionBottomRightPadding((640, 640), 114),
             DetectionBottomRightPadding((640, 640), 114),
             ImagePermute((2, 0, 1)),
             ImagePermute((2, 0, 1)),
         ]
         ]
     )
     )
-    class_names = COCO_DETECTION_CLASSES_LIST
-    return class_names, image_processor
+
+    params = dict(
+        class_names=COCO_DETECTION_CLASSES_LIST,
+        image_processor=image_processor,
+        iou=0.65,
+        conf=0.1,
+    )
+    return params
 
 
 
 
-def default_ppyoloe_coco_processing_params() -> Tuple[List[str], Processing]:
-    """Processing parameters commonly used for training PPYoloE on COCO dataset."""
-    from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
+def default_ppyoloe_coco_processing_params() -> dict:
+    """Processing parameters commonly used for training PPYoloE on COCO dataset.
+    TODO: remove once we load it from the checkpoint
+    """
 
 
     image_processor = ComposeProcessing(
     image_processor = ComposeProcessing(
         [
         [
+            ReverseImageChannels(),
             DetectionRescale(output_shape=(640, 640)),
             DetectionRescale(output_shape=(640, 640)),
             NormalizeImage(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
             NormalizeImage(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
             ImagePermute(permutation=(2, 0, 1)),
             ImagePermute(permutation=(2, 0, 1)),
         ]
         ]
     )
     )
-    class_names = COCO_DETECTION_CLASSES_LIST
-    return class_names, image_processor
+
+    params = dict(
+        class_names=COCO_DETECTION_CLASSES_LIST,
+        image_processor=image_processor,
+        iou=0.65,
+        conf=0.5,
+    )
+    return params
+
+
+def default_deciyolo_coco_processing_params() -> dict:
+    """Processing parameters commonly used for training DeciYolo on COCO dataset.
+    TODO: remove once we load it from the checkpoint
+    """
+
+    image_processor = ComposeProcessing(
+        [
+            DetectionLongestMaxSizeRescale(output_shape=(636, 636)),
+            DetectionCenterPadding(output_shape=(640, 640), pad_value=114),
+            StandardizeImage(max_value=255.0),
+            ImagePermute(permutation=(2, 0, 1)),
+        ]
+    )
+
+    params = dict(
+        class_names=COCO_DETECTION_CLASSES_LIST,
+        image_processor=image_processor,
+        iou=0.65,
+        conf=0.5,
+    )
+    return params
+
+
+def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -> dict:
+    """Get the processing parameters for a pretrained model.
+    TODO: remove once we load it from the checkpoint
+    """
+    if pretrained_weights == "coco":
+        if "yolox" in model_name:
+            return default_yolox_coco_processing_params()
+        elif "ppyoloe" in model_name:
+            return default_ppyoloe_coco_processing_params()
+        elif "deciyolo" in model_name:
+            return default_deciyolo_coco_processing_params()
+    return dict()
Discard
@@ -1,87 +0,0 @@
-from typing import Union, List
-import PIL
-
-import numpy as np
-import torch
-import requests
-from urllib.parse import urlparse
-
-IMG_EXTENSIONS = ("bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm")
-ImageType = Union[str, np.ndarray, torch.Tensor, PIL.Image.Image]
-
-
-def load_images(images: Union[List[ImageType], ImageType]) -> List[np.ndarray]:
-    """Load a single image or a list of images and return them as a list of numpy arrays.
-
-    Supported image types include:
-        - numpy.ndarray:    A numpy array representing the image
-        - torch.Tensor:     A PyTorch tensor representing the image
-        - PIL.Image.Image:  A PIL Image object
-        - str:              A string representing either a local file path or a URL to an image
-
-    :param images:  Single image or a list of images of supported types.
-    :return:        List of images as numpy arrays. If loaded from string, the image will be returned as RGB.
-    """
-    if isinstance(images, list):
-        return [load_image(image=image) for image in images]
-    else:
-        return [load_image(image=images)]
-
-
-def load_image(image: ImageType) -> np.ndarray:
-    """Load a single image and return it as a numpy arrays.
-
-    Supported image types include:
-        - numpy.ndarray:    A numpy array representing the image
-        - torch.Tensor:     A PyTorch tensor representing the image
-        - PIL.Image.Image:  A PIL Image object
-        - str:              A string representing either a local file path or a URL to an image
-
-    :param image: Single image of supported types.
-    :return:      Image as numpy arrays. If loaded from string, the image will be returned as RGB.
-    """
-    if isinstance(image, np.ndarray):
-        return image
-    elif isinstance(image, torch.Tensor):
-        return image.numpy()
-    elif isinstance(image, PIL.Image.Image):
-        return load_np_image_from_pil(image)
-    elif isinstance(image, str):
-        image = load_pil_image_from_str(image_str=image)
-        return load_np_image_from_pil(image)
-    else:
-        raise ValueError(f"Unsupported image type: {type(image)}")
-
-
-def load_np_image_from_pil(image: PIL.Image.Image) -> np.ndarray:
-    """Convert a PIL image to numpy array in RGB format."""
-    return np.asarray(image.convert("RGB"))
-
-
-def load_pil_image_from_str(image_str: str) -> PIL.Image.Image:
-    """Load an image based on a string (local file path or URL)."""
-
-    if is_url(image_str):
-        response = requests.get(image_str, stream=True)
-        response.raise_for_status()
-        return PIL.Image.open(response.raw)
-    else:
-        return PIL.Image.open(image_str)
-
-
-def is_url(url: str) -> bool:
-    """Check if the given string is a URL."""
-    try:
-        result = urlparse(url)
-        return all([result.scheme, result.netloc, result.path])
-    except Exception:
-        return False
-
-
-def is_image(filename: str) -> bool:
-    """Check if the given file name refers to image.
-
-    :param filename:    The filename to check.
-    :return:            True if the file is an image, False otherwise.
-    """
-    return filename.split(".")[-1].lower() in IMG_EXTENSIONS
Discard
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    1. from typing import Union, List, Iterable, Iterator
    2. from typing_extensions import get_args
    3. import PIL
    4. import os
    5. from PIL import Image
    6. import matplotlib.pyplot as plt
    7. import numpy as np
    8. import torch
    9. import requests
    10. from urllib.parse import urlparse
    11. IMG_EXTENSIONS = ("bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm")
    12. SingleImageSource = Union[str, np.ndarray, torch.Tensor, PIL.Image.Image]
    13. ImageSource = Union[SingleImageSource, List[SingleImageSource]]
    14. def load_images(images: Union[List[ImageSource], ImageSource]) -> List[np.ndarray]:
    15. """Load a single image or a list of images and return them as a list of numpy arrays.
    16. Supported types include:
    17. - str: A string representing either an image or an URL.
    18. - numpy.ndarray: A numpy array representing the image
    19. - torch.Tensor: A PyTorch tensor representing the image
    20. - PIL.Image.Image: A PIL Image object
    21. - List: A list of images of any of the above types.
    22. :param images: Single image or a list of images of supported types.
    23. :return: List of images as numpy arrays. If loaded from string, the image will be returned as RGB.
    24. """
    25. return [image for image in generate_image_loader(images=images)]
    26. def generate_image_loader(images: Union[List[ImageSource], ImageSource]) -> Iterable[np.ndarray]:
    27. """Generator that loads images one at a time.
    28. Supported types include:
    29. - str: A string representing either an image or an URL.
    30. - numpy.ndarray: A numpy array representing the image
    31. - torch.Tensor: A PyTorch tensor representing the image
    32. - PIL.Image.Image: A PIL Image object
    33. - List: A list of images of any of the above types.
    34. :param images: Single image or a list of images of supported types.
    35. :return: Generator of images as numpy arrays. If loaded from string, the image will be returned as RGB.
    36. """
    37. if isinstance(images, str) and os.path.isdir(images):
    38. images_paths = list_images_in_folder(images)
    39. for image_path in images_paths:
    40. yield load_image(image=image_path)
    41. elif isinstance(images, (list, Iterator)):
    42. for image in images:
    43. yield load_image(image=image)
    44. else:
    45. yield load_image(image=images)
    46. def list_images_in_folder(directory: str) -> List[str]:
    47. """List all the images in a directory.
    48. :param directory: The path to the directory containing the images.
    49. :return: A list of image file names.
    50. """
    51. files = os.listdir(directory)
    52. images_paths = [os.path.join(directory, f) for f in files if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif"))]
    53. return images_paths
    54. def load_image(image: ImageSource) -> np.ndarray:
    55. """Load a single image and return it as a numpy arrays.
    56. Supported image types include:
    57. - numpy.ndarray: A numpy array representing the image
    58. - torch.Tensor: A PyTorch tensor representing the image
    59. - PIL.Image.Image: A PIL Image object
    60. - str: A string representing either a local file path or a URL to an image
    61. :param image: Single image of supported types.
    62. :return: Image as numpy arrays. If loaded from string, the image will be returned as RGB.
    63. """
    64. if isinstance(image, np.ndarray):
    65. return image
    66. elif isinstance(image, torch.Tensor):
    67. return image.numpy()
    68. elif isinstance(image, PIL.Image.Image):
    69. return load_np_image_from_pil(image)
    70. elif isinstance(image, str):
    71. image = load_pil_image_from_str(image_str=image)
    72. return load_np_image_from_pil(image)
    73. else:
    74. raise ValueError(f"Unsupported image type: {type(image)}")
    75. def load_np_image_from_pil(image: PIL.Image.Image) -> np.ndarray:
    76. """Convert a PIL image to numpy array in RGB format."""
    77. return np.asarray(image.convert("RGB"))
    78. def load_pil_image_from_str(image_str: str) -> PIL.Image.Image:
    79. """Load an image based on a string (local file path or URL)."""
    80. if is_url(image_str):
    81. response = requests.get(image_str, stream=True)
    82. response.raise_for_status()
    83. return PIL.Image.open(response.raw)
    84. else:
    85. return PIL.Image.open(image_str)
    86. def save_image(image: np.ndarray, path: str) -> None:
    87. """Save a numpy array as an image.
    88. :param image: Image to save, (H, W, C), RGB.
    89. :param path: Path to save the image to.
    90. """
    91. Image.fromarray(image).save(path)
    92. def is_url(url: str) -> bool:
    93. """Check if the given string is a URL.
    94. :param url: String to check.
    95. """
    96. try:
    97. result = urlparse(url)
    98. return all([result.scheme, result.netloc, result.path])
    99. except Exception:
    100. return False
    101. def show_image(image: np.ndarray) -> None:
    102. """Show an image using matplotlib.
    103. :param image: Image to show in (H, W, C), RGB.
    104. """
    105. plt.imshow(image, interpolation="nearest")
    106. plt.axis("off")
    107. plt.show()
    108. def check_image_typing(image: ImageSource) -> bool:
    109. """Check if the given object respects typing of image.
    110. :param image: Image to check.
    111. :return: True if the object is an image, False otherwise.
    112. """
    113. if isinstance(image, get_args(SingleImageSource)):
    114. return True
    115. elif isinstance(image, list):
    116. return all([isinstance(image_item, get_args(SingleImageSource)) for image_item in image])
    117. else:
    118. return False
    119. def is_image(filename: str) -> bool:
    120. """Check if the given file name refers to image.
    121. :param filename: The filename to check.
    122. :return: True if the file is an image, False otherwise.
    123. """
    124. return filename.split(".")[-1].lower() in IMG_EXTENSIONS
    Discard
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    1. import cv2
    2. import numpy as np
    3. import time
    4. from typing import Callable, Optional
    5. __all__ = ["WebcamStreaming"]
    6. class WebcamStreaming:
    7. """Stream video from a webcam. Press 'q' to quit the streaming.
    8. :param window_name: Name of the window to display the video stream.
    9. :param frame_processing_fn: Function to apply to each frame before displaying it.
    10. If None, frames are displayed as is.
    11. :param capture: ID of the video capture device to use.
    12. Default is cv2.CAP_ANY (which selects the first available device).
    13. :param fps_update_frequency: Minimum time (in seconds) between updates to the FPS counter.
    14. If None, the counter is updated every frame.
    15. """
    16. def __init__(
    17. self,
    18. window_name: str = "",
    19. frame_processing_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    20. capture: int = cv2.CAP_ANY,
    21. fps_update_frequency: Optional[float] = None,
    22. ):
    23. self.window_name = window_name
    24. self.frame_processing_fn = frame_processing_fn
    25. self.cap = cv2.VideoCapture(capture)
    26. if not self.cap.isOpened():
    27. raise ValueError("Could not open video capture device")
    28. self._fps_counter = FPSCounter(update_frequency=fps_update_frequency)
    29. def run(self) -> None:
    30. """Start streaming video from the webcam and displaying it in a window.
    31. Press 'q' to quit the streaming.
    32. """
    33. while not self._stop():
    34. self._display_single_frame()
    35. def _display_single_frame(self) -> None:
    36. """Read a single frame from the video capture device, apply any specified frame processing,
    37. and display the resulting frame in the window.
    38. Also updates the FPS counter and displays it in the frame.
    39. """
    40. _ret, frame = self.cap.read()
    41. if self.frame_processing_fn:
    42. frame = self.frame_processing_fn(frame)
    43. _write_fps_to_frame(frame, self.fps)
    44. cv2.imshow(self.window_name, frame)
    45. def _stop(self) -> bool:
    46. """Stopping condition for the streaming."""
    47. return cv2.waitKey(1) & 0xFF == ord("q")
    48. @property
    49. def fps(self) -> float:
    50. return self._fps_counter.fps
    51. def __del__(self):
    52. """Release the video capture device and close the window."""
    53. self.cap.release()
    54. cv2.destroyAllWindows()
    55. def _write_fps_to_frame(frame: np.ndarray, fps: float) -> None:
    56. """Write the current FPS value on the given frame.
    57. :param frame: Frame to write the FPS value on.
    58. :param fps: Current FPS value to write.
    59. """
    60. font = cv2.FONT_HERSHEY_SIMPLEX
    61. font_scale = 0.6
    62. font_color = (0, 255, 0)
    63. line_type = 2
    64. cv2.putText(frame, "FPS: {:.2f}".format(fps), (10, 30), font, font_scale, font_color, line_type)
    65. class FPSCounter:
    66. """Class for calculating the FPS of a video stream."""
    67. def __init__(self, update_frequency: Optional[float] = None):
    68. """Create a new FPSCounter object.
    69. :param update_frequency: Minimum time (in seconds) between updates to the FPS counter.
    70. If None, the counter is updated every frame.
    71. """
    72. self._update_frequency = update_frequency
    73. self._start_time = time.time()
    74. self._frame_count = 0
    75. self._fps = 0.0
    76. def _update_fps(self, elapsed_time, current_time) -> None:
    77. """Compute new value of FPS and reset the counter."""
    78. self._fps = self._frame_count / elapsed_time
    79. self._start_time = current_time
    80. self._frame_count = 0
    81. @property
    82. def fps(self) -> float:
    83. """Current FPS value."""
    84. self._frame_count += 1
    85. current_time, elapsed_time = time.time(), time.time() - self._start_time
    86. if self._update_frequency is None or elapsed_time > self._update_frequency:
    87. self._update_fps(elapsed_time=elapsed_time, current_time=current_time)
    88. return self._fps
    Discard
    @@ -4,7 +4,7 @@ import cv2
     import numpy as np
     import numpy as np
     
     
     
     
    -__all__ = ["load_video", "save_video"]
    +__all__ = ["load_video", "save_video", "is_video", "show_video_from_disk", "show_video_from_frames"]
     
     
     
     
     def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
     def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
    @@ -30,6 +30,7 @@ def _open_video(file_path: str) -> cv2.VideoCapture:
         :return:            Opened video capture object
         :return:            Opened video capture object
         """
         """
         cap = cv2.VideoCapture(file_path)
         cap = cv2.VideoCapture(file_path)
    +
         if not cap.isOpened():
         if not cap.isOpened():
             raise ValueError(f"Failed to open video file: {file_path}")
             raise ValueError(f"Failed to open video file: {file_path}")
         return cap
         return cap
    @@ -97,3 +98,60 @@ def _validate_frames(frames: List[np.ndarray]) -> Tuple[float, float]:
             raise RuntimeError("Your frames must include 3 channels.")
             raise RuntimeError("Your frames must include 3 channels.")
     
     
         return max_height, max_width
         return max_height, max_width
    +
    +
    +def show_video_from_disk(video_path: str, window_name: str = "Prediction"):
    +    """Display a video from disk using OpenCV.
    +
    +    :param video_path:   Path to the video file.
    +    :param window_name:  Name of the window to display the video
    +    """
    +    cap = _open_video(video_path)
    +    fps = cap.get(cv2.CAP_PROP_FPS)
    +
    +    while cap.isOpened():
    +        ret, frame = cap.read()
    +
    +        if ret:
    +            # Display the frame
    +            cv2.imshow(window_name, frame)
    +
    +            # Wait for the specified number of milliseconds before displaying the next frame
    +            if cv2.waitKey(int(1000 / fps)) & 0xFF == ord("q"):
    +                break
    +        else:
    +            break
    +
    +    # Release the VideoCapture object and destroy the window
    +    cap.release()
    +    cv2.destroyAllWindows()
    +    cv2.waitKey(1)
    +
    +
    +def show_video_from_frames(frames: List[np.ndarray], fps: float, window_name: str = "Prediction") -> None:
    +    """Display a video from a list of frames using OpenCV.
    +
    +    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    +    :param fps:         Frames per second
    +    :param window_name:  Name of the window to display the video
    +    """
    +    for frame in frames:
    +        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    +        cv2.imshow(window_name, frame)
    +        cv2.waitKey(int(1000 / fps))
    +    cv2.destroyAllWindows()
    +    cv2.waitKey(1)
    +
    +
    +def is_video(file_path: str) -> bool:
    +    """Check if a file is a video file.
    +    :param file_path:   Path to the video file.
    +    :return:            True if the file is a video file, False otherwise.
    +    """
    +    try:
    +        cap = cv2.VideoCapture(file_path, apiPreference=cv2.CAP_FFMPEG)
    +        if cap.isOpened():
    +            cap.release()
    +            return True
    +    except Exception:
    +        return False
    Discard
    @@ -1,18 +1,19 @@
    +import os
    +import tarfile
    +import re
     import math
     import math
     import time
     import time
     from functools import lru_cache
     from functools import lru_cache
     from pathlib import Path
     from pathlib import Path
    -from typing import Mapping, Optional, Tuple, Union, List, Dict, Any
    +from typing import Mapping, Optional, Tuple, Union, List, Dict, Any, Iterable
     from zipfile import ZipFile
     from zipfile import ZipFile
    -import os
     from jsonschema import validate
     from jsonschema import validate
    -import tarfile
    +from itertools import islice
    +
     from PIL import Image, ExifTags
     from PIL import Image, ExifTags
    -import re
     import torch
     import torch
     import torch.nn as nn
     import torch.nn as nn
     
     
    -
     # These functions changed from torch 1.2 to torch 1.3
     # These functions changed from torch 1.2 to torch 1.3
     
     
     import random
     import random
    @@ -526,3 +527,14 @@ def override_default_params_without_nones(params: Dict, default_params: Mapping)
             if key not in params.keys() or params[key] is None:
             if key not in params.keys() or params[key] is None:
                 params[key] = val
                 params[key] = val
         return params
         return params
    +
    +
    +def generate_batch(iterable: Iterable, batch_size: int) -> Iterable:
    +    """Batch data into tuples of length n. The last batch may be shorter."""
    +    it = iter(iterable)
    +    while True:
    +        batch = tuple(islice(it, batch_size))
    +        if batch:
    +            yield batch
    +        else:
    +            return
    Discard