Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#224 Feature/sg 13 forward pass prep fn

Merged
Shay Aharon merged 1 commit into Deci-AI:master from deci-ai:feature/SG-13_forward_pass_prep_fn
@@ -160,6 +160,7 @@ class SgModel:
         self.scaler = None
         self.scaler = None
         self.phase_callbacks = None
         self.phase_callbacks = None
         self.checkpoint_params = None
         self.checkpoint_params = None
+        self.pre_prediction_callback = None
 
 
         # SET THE DEFAULT PROPERTIES
         # SET THE DEFAULT PROPERTIES
         self.half_precision = False
         self.half_precision = False
@@ -364,11 +365,15 @@ class SgModel:
                                criterion=self.criterion,
                                criterion=self.criterion,
                                device=self.device,
                                device=self.device,
                                lr_warmup_epochs=self.training_params.lr_warmup_epochs,
                                lr_warmup_epochs=self.training_params.lr_warmup_epochs,
-                               sg_logger=self.sg_logger)
+                               sg_logger=self.sg_logger,
+                               train_loader=self.train_loader)
 
 
         for batch_idx, batch_items in enumerate(progress_bar_train_loader):
         for batch_idx, batch_items in enumerate(progress_bar_train_loader):
             batch_items = core_utils.tensor_container_to_device(batch_items, self.device, non_blocking=True)
             batch_items = core_utils.tensor_container_to_device(batch_items, self.device, non_blocking=True)
             inputs, targets, additional_batch_items = sg_model_utils.unpack_batch_items(batch_items)
             inputs, targets, additional_batch_items = sg_model_utils.unpack_batch_items(batch_items)
+
+            if self.pre_prediction_callback is not None:
+                inputs, targets = self.pre_prediction_callback(inputs, targets, batch_idx)
             # AUTOCAST IS ENABLED ONLY IF self.training_params.mixed_precision - IF enabled=False AUTOCAST HAS NO EFFECT
             # AUTOCAST IS ENABLED ONLY IF self.training_params.mixed_precision - IF enabled=False AUTOCAST HAS NO EFFECT
             with autocast(enabled=self.training_params.mixed_precision):
             with autocast(enabled=self.training_params.mixed_precision):
                 # FORWARD PASS TO GET NETWORK'S PREDICTIONS
                 # FORWARD PASS TO GET NETWORK'S PREDICTIONS
@@ -757,6 +762,13 @@ class SgModel:
 
 
                     Number of epochs to cooldown LR (i.e the last epoch from scheduling view point=max_epochs-cooldown).
                     Number of epochs to cooldown LR (i.e the last epoch from scheduling view point=max_epochs-cooldown).
 
 
+                -   `pre_prediction_callback` : Callable (default=None)
+
+                     When not None, this callback is applied to the images and targets, and its return values are
+                      used for the forward pass and further computations. Args for this callable should be in the
+                      order (inputs, targets, batch_idx), returning modified_inputs, modified_targets
+
+
 
 
         :return:
         :return:
         """
         """
@@ -908,11 +920,20 @@ class SgModel:
         if self.load_checkpoint and load_opt_params:
         if self.load_checkpoint and load_opt_params:
             self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
             self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
 
 
+        self.pre_prediction_callback = self.training_params.pre_prediction_callback
+
         self._initialize_mixed_precision(self.training_params.mixed_precision)
         self._initialize_mixed_precision(self.training_params.mixed_precision)
 
 
-        context = PhaseContext(optimizer=self.optimizer, net=self.net, experiment_name=self.experiment_name,
-                               ckpt_dir=self.checkpoints_dir_path, criterion=self.criterion,
-                               lr_warmup_epochs=self.training_params.lr_warmup_epochs, sg_logger=self.sg_logger)
+        context = PhaseContext(optimizer=self.optimizer,
+                               net=self.net,
+                               experiment_name=self.experiment_name,
+                               ckpt_dir=self.checkpoints_dir_path,
+                               criterion=self.criterion,
+                               lr_warmup_epochs=self.training_params.lr_warmup_epochs,
+                               sg_logger=self.sg_logger,
+                               train_loader=self.train_loader,
+                               valid_loader=self.valid_loader)
+
         self.phase_callback_handler(Phase.PRE_TRAINING, context)
         self.phase_callback_handler(Phase.PRE_TRAINING, context)
 
 
         try:
         try:
Discard
Tip!

Press p to see the previous file, or n to see the next file