#224 Feature/sg 13 forward pass prep fn

Merged
Shay Aharon merged 1 commit into Deci-AI:master from deci-ai:feature/SG-13_forward_pass_prep_fn
@@ -25,6 +25,7 @@ from tqdm import tqdm
 from super_gradients.training.utils.utils import AverageMeter
 from super_gradients.training.utils.detection_utils import DetectionVisualization
 import uuid
+from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size

 import matplotlib.pyplot as plt

@@ -191,6 +192,123 @@ class MultiScaleCollateFunction(AbstractCollateFunction):
             return images, batch[1]


+class AbstractPrePredictionCallback(ABC):
+    """
+    Abstract class for forward-pass preprocessing functions, to be used by passing one of its inheritors through the
+     training_params pre_prediction_callback keyword arg.
+
+    Inheritors should implement __call__ and return images, targets after applying the desired preprocessing.
+    """
+    @abstractmethod
+    def __call__(self, inputs, targets, batch_idx):
+        pass
+
+
+class MultiscalePrePredictionCallback(AbstractPrePredictionCallback):
+    """
+    Multiscale pre-prediction callback.
+
+    When passed through train_params, the below transform will be applied to images and targets to support multi-scaling
+    on the fly.
+
+    Every self.frequency forward passes, the input size is randomly changed to a value from
+     (input_size-self.multiscale_range*self.image_size_steps, input_size-(self.multiscale_range-1)*self.image_size_steps,
+     ...input_size+self.multiscale_range*self.image_size_steps)
+
+    Attributes:
+        multiscale_range: (int) Range of values for resize sizes as discussed above (default=5)
+        image_size_steps: (int) Image step sizes as discussed above (default=32)
+        change_frequency: (int) The frequency to apply change in input size.
+    """
+
+    def __init__(self, multiscale_range: int = 5,
+                 image_size_steps: int = 32,
+                 change_frequency: int = 10):
+
+        self.multiscale_range = multiscale_range
+        self.image_size_steps = image_size_steps
+        self.frequency = change_frequency
+        self.rank = None
+        self.is_distributed = None
+
+    def __call__(self, inputs, targets, batch_idx):
+        if self.rank is None:
+            self.rank = get_local_rank()
+        if self.is_distributed is None:
+            self.is_distributed = get_world_size() > 1
+
+        # GENERATE A NEW SIZE AND BROADCAST IT TO THE OTHER RANKS SO THEY HAVE THE SAME SCALE
+        if batch_idx % self.frequency == 0:
+            tensor = torch.LongTensor(2).cuda()
+            input_size = inputs.shape[2:]
+
+            if self.rank == 0:
+                size_factor = input_size[1] * 1.0 / input_size[0]
+                min_size = int(input_size[0] / self.image_size_steps) - self.multiscale_range
+                max_size = int(input_size[0] / self.image_size_steps) + self.multiscale_range
+                random_size = (min_size, max_size)
+                size = random.randint(*random_size)
+                size = (int(self.image_size_steps * size), self.image_size_steps * int(size * size_factor))
+                tensor[0] = size[0]
+                tensor[1] = size[1]
+
+            if self.is_distributed:
+                dist.barrier()
+                dist.broadcast(tensor, 0)
+
+            new_input_size = (tensor[0].item(), tensor[1].item())
+
+            scale_y = new_input_size[0] / input_size[0]
+            scale_x = new_input_size[1] / input_size[1]
+            if scale_x != 1 or scale_y != 1:
+                inputs = torch.nn.functional.interpolate(inputs, size=new_input_size, mode="bilinear", align_corners=False)
+        return inputs, targets
+
+
+class DetectionMultiscalePrePredictionCallback(MultiscalePrePredictionCallback):
+    """
+    Multiscale pre-prediction callback for object detection.
+
+    When passed through train_params, the below transform will be applied to images and targets to support multi-scaling
+    on the fly.
+
+    Every self.frequency forward passes, the input size is randomly changed to a value from
+     (input_size-self.multiscale_range*self.image_size_steps, input_size-(self.multiscale_range-1)*self.image_size_steps,
+     ...input_size+self.multiscale_range*self.image_size_steps) and the same rescaling is applied to the box coordinates.
+
+    Attributes:
+        multiscale_range: (int) Range of values for resize sizes as discussed above (default=5)
+        image_size_steps: (int) Image step sizes as discussed above (default=32)
+        change_frequency: (int) The frequency to apply change in input size.
+    """
+
+    def __init__(self, multiscale_range: int = 5,
+                 image_size_steps: int = 32,
+                 change_frequency: int = 10):
+
+        self.multiscale_range = multiscale_range
+        self.image_size_steps = image_size_steps
+        self.frequency = change_frequency
+        self.rank = None
+        self.is_distributed = None
+
+    def __call__(self, inputs, targets, batch_idx):
+        # RESCALE THE IMAGE FIRST WITH SUPER(), AND IF RESCALING HAS ACTUALLY BEEN DONE APPLY TO BOXES AS WELL
+        input_size = inputs.shape[2:]
+        inputs, targets = super(DetectionMultiscalePrePredictionCallback, self).__call__(inputs, targets, batch_idx)
+        new_input_size = inputs.shape[2:]
+        scale_y = new_input_size[0] / input_size[0]
+        scale_x = new_input_size[1] / input_size[1]
+        if scale_x != 1 or scale_y != 1:
+            targets[..., 2::2] = targets[..., 2::2] * scale_x
+            targets[..., 3::2] = targets[..., 3::2] * scale_y
+        return inputs, targets
+
+
 _pil_interpolation_to_str = {
     Image.NEAREST: 'PIL.Image.NEAREST',
     Image.BILINEAR: 'PIL.Image.BILINEAR',
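For reference, this is how a custom hook is meant to plug in: subclass AbstractPrePredictionCallback and implement __call__ with the (inputs, targets, batch_idx) signature. A minimal sketch, assuming the classes above are importable from the module this diff touches (the import path and the half-resolution transform are illustrative, not part of this PR):

import torch

# Assumed import path for the class added in this diff; adjust to your checkout.
from super_gradients.training.datasets.datasets_utils import AbstractPrePredictionCallback


class HalfResolutionCallback(AbstractPrePredictionCallback):
    """Illustrative callback: downsample every batch to half resolution before the forward pass."""

    def __call__(self, inputs, targets, batch_idx):
        # inputs is an NCHW batch tensor; targets pass through unchanged
        new_size = (inputs.shape[2] // 2, inputs.shape[3] // 2)
        inputs = torch.nn.functional.interpolate(inputs, size=new_size, mode="bilinear", align_corners=False)
        return inputs, targets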
@@ -47,7 +47,8 @@ DEFAULT_TRAINING_PARAMS = {"lr_warmup_epochs": 0,
                            "warmup_mode": "linear_step",
                            "step_lr_update_freq": None,
                            "lr_updates": [],
-                           'clip_grad_norm': None
+                           'clip_grad_norm': None,
+                           'pre_prediction_callback': None
                            }

 DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}
@@ -160,6 +160,7 @@ class SgModel:
         self.scaler = None
         self.phase_callbacks = None
         self.checkpoint_params = None
+        self.pre_prediction_callback = None

         # SET THE DEFAULT PROPERTIES
         self.half_precision = False
@@ -364,11 +365,15 @@ class SgModel:
                                criterion=self.criterion,
                                device=self.device,
                                lr_warmup_epochs=self.training_params.lr_warmup_epochs,
-                               sg_logger=self.sg_logger)
+                               sg_logger=self.sg_logger,
+                               train_loader=self.train_loader)

         for batch_idx, batch_items in enumerate(progress_bar_train_loader):
             batch_items = core_utils.tensor_container_to_device(batch_items, self.device, non_blocking=True)
             inputs, targets, additional_batch_items = sg_model_utils.unpack_batch_items(batch_items)
+
+            if self.pre_prediction_callback is not None:
+                inputs, targets = self.pre_prediction_callback(inputs, targets, batch_idx)
             # AUTOCAST IS ENABLED ONLY IF self.training_params.mixed_precision - IF enabled=False AUTOCAST HAS NO EFFECT
             with autocast(enabled=self.training_params.mixed_precision):
                 # FORWARD PASS TO GET NETWORK'S PREDICTIONS
@@ -757,6 +762,13 @@ class SgModel:

                     Number of epochs to cooldown LR (i.e the last epoch from scheduling view point=max_epochs-cooldown).

+                -   `pre_prediction_callback` : Callable (default=None)
+
+                     When not None, this callback will be applied to images and targets, and its return values will be
+                      used for the forward pass and further computations. The callable should accept args in the order
+                      (inputs, targets, batch_idx) and return modified_inputs, modified_targets.
+
+

         :return:
         """
@@ -908,11 +920,20 @@ class SgModel:
         if self.load_checkpoint and load_opt_params:
             self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])

+        self.pre_prediction_callback = self.training_params.pre_prediction_callback
+
         self._initialize_mixed_precision(self.training_params.mixed_precision)

-        context = PhaseContext(optimizer=self.optimizer, net=self.net, experiment_name=self.experiment_name,
-                               ckpt_dir=self.checkpoints_dir_path, criterion=self.criterion,
-                               lr_warmup_epochs=self.training_params.lr_warmup_epochs, sg_logger=self.sg_logger)
+        context = PhaseContext(optimizer=self.optimizer,
+                               net=self.net,
+                               experiment_name=self.experiment_name,
+                               ckpt_dir=self.checkpoints_dir_path,
+                               criterion=self.criterion,
+                               lr_warmup_epochs=self.training_params.lr_warmup_epochs,
+                               sg_logger=self.sg_logger,
+                               train_loader=self.train_loader,
+                               valid_loader=self.valid_loader)
+
         self.phase_callback_handler(Phase.PRE_TRAINING, context)

         try:
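Since pre_prediction_callback only needs to be a Callable with the (inputs, targets, batch_idx) signature, a plain function works just as well as an AbstractPrePredictionCallback inheritor. A minimal sketch (the padding transform is illustrative, not from this PR):

import torch


def pad_to_multiple_of_32(inputs, targets, batch_idx):
    """Illustrative pre_prediction_callback: zero-pad H and W up to the next multiple of 32."""
    h, w = inputs.shape[2:]
    pad_h, pad_w = (-h) % 32, (-w) % 32
    if pad_h or pad_w:
        # F.pad pads the last two dims in (left, right, top, bottom) order
        inputs = torch.nn.functional.pad(inputs, (0, pad_w, 0, pad_h))
    return inputs, targets


train_params = {"pre_prediction_callback": pad_to_multiple_of_32}  # plus the usual training params

The unit test added at the bottom of this PR exercises exactly this plain-function path.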
@@ -126,6 +126,18 @@ def get_local_rank():
     return dist.get_rank() if dist.is_initialized() else 0


+def get_world_size() -> int:
+    """
+    Returns the world size if running in DDP, and 1 otherwise
+    :return: world size
+    """
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+
 @contextmanager
 def wait_for_the_master(local_rank: int):
     """
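These two helpers are what keep all DDP workers on the same scale in the multiscale callback: rank 0 draws the new size, and a broadcast makes every other rank agree on it. A stripped-down sketch of that pattern, assuming an initialized process group and a CUDA device as in the callback above (the 480x640 value is illustrative):

import torch
import torch.distributed as dist

from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size


def draw_and_broadcast_size():
    rank, is_distributed = get_local_rank(), get_world_size() > 1
    size_tensor = torch.LongTensor(2).cuda()
    if rank == 0:
        size_tensor[0], size_tensor[1] = 480, 640  # illustrative (H, W); rank 0 decides
    if is_distributed:
        dist.barrier()                   # make sure every rank has reached this point
        dist.broadcast(size_tensor, 0)   # overwrite every rank's tensor with rank 0's copy
    return size_tensor[0].item(), size_tensor[1].item()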
@@ -19,6 +19,7 @@ from tests.unit_tests.vit_unit_test import TestViT
 from tests.unit_tests.yolox_unit_test import TestYOLOX
 from tests.unit_tests.lr_cooldown_test import LRCooldownTest
 from tests.unit_tests.detection_targets_format_transform_test import DetectionTargetsTransformTest
+from tests.unit_tests.forward_pass_prep_fn_test import ForwardpassPrepFNTest


 class CoreUnitTestSuiteRunner:
@@ -61,6 +62,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(InitializeWithDataloadersTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LRCooldownTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DetectionTargetsTransformTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ForwardpassPrepFNTest))

     def _add_modules_to_end_to_end_tests_suite(self):
         """
tests/unit_tests/forward_pass_prep_fn_test.py:

import unittest

from super_gradients.training import SgModel
from super_gradients.training.metrics import Accuracy
from super_gradients.training.datasets import ClassificationTestDatasetInterface
from super_gradients.training.utils.callbacks import PhaseCallback, Phase, PhaseContext
import torch


class TestInputSizesCallback(PhaseCallback):
    """
    Phase callback that collects the input shapes in shapes_placeholder at the end of each forward pass.
    """

    def __init__(self, shapes_placeholder):
        super(TestInputSizesCallback, self).__init__(Phase.TRAIN_BATCH_END)
        self.shapes_placeholder = shapes_placeholder

    def __call__(self, context: PhaseContext):
        self.shapes_placeholder.append(context.inputs.shape)


def test_forward_pass_prep_fn(inputs, targets, *args, **kwargs):
    inputs = torch.nn.functional.interpolate(
        inputs, size=(50, 50), mode="bilinear", align_corners=False
    )
    return inputs, targets


class ForwardpassPrepFNTest(unittest.TestCase):
    def setUp(self) -> None:
        self.dataset_params = {"batch_size": 4}
        self.dataset = ClassificationTestDatasetInterface(dataset_params=self.dataset_params)
        self.arch_params = {'num_classes': 10}

    def test_resizing_with_forward_pass_prep_fn(self):
        # Define Model
        model = SgModel("ForwardpassPrepFNTest")
        model.connect_dataset_interface(self.dataset)
        model.build_model("resnet18", arch_params=self.arch_params)

        sizes = []
        phase_callbacks = [TestInputSizesCallback(sizes)]

        train_params = {"max_epochs": 2, "cosine_final_lr_ratio": 0.2, "lr_mode": "cosine",
                        "lr_cooldown_epochs": 2,
                        "lr_warmup_epochs": 3, "initial_lr": 1, "loss": "cross_entropy", "optimizer": 'SGD',
                        "criterion_params": {}, "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
                        "train_metrics_list": [Accuracy()], "valid_metrics_list": [Accuracy()],
                        "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
                        "greater_metric_to_watch_is_better": True, "ema": False, "phase_callbacks": phase_callbacks,
                        "pre_prediction_callback": test_forward_pass_prep_fn}
        model.train(train_params)

        # EVERY BATCH GOES THROUGH test_forward_pass_prep_fn, SO EVERY COLLECTED INPUT SHAPE SHOULD BE 50x50
        sizes = list(map(lambda size: size[2], sizes))
        self.assertTrue(all(map(lambda size: size == 50, sizes)))