#309 Fix scale between rescaling batches

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-221-make_multiscale_keep_state
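At a glance (inferred from the diff below, since the PR carries no description): the sampled multiscale size is now stored on the callback as self.new_input_size and the rescaling is applied on every batch, so batches between two sampling points keep the same scale instead of only the sampling batch being resized. The broadcast tensor is also allocated on the inputs' device rather than hard-coded to CUDA, and a unit test covering the kept state is added and registered in the core unit test suite.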
@@ -231,6 +231,7 @@ class MultiscalePrePredictionCallback(AbstractPrePredictionCallback):
         self.rank = None
         self.is_distributed = None
         self.sampled_imres_once = False
+        self.new_input_size = None

     def __call__(self, inputs, targets, batch_idx):
         if self.rank is None:
@@ -239,9 +240,9 @@ class MultiscalePrePredictionCallback(AbstractPrePredictionCallback):
             self.is_distributed = get_world_size() > 1

         # GENERATE A NEW SIZE AND BROADCAST IT TO THE OTHER RANKS SO THEY HAVE THE SAME SCALE
+        input_size = inputs.shape[2:]
         if batch_idx % self.frequency == 0:
-            tensor = torch.LongTensor(2).cuda()
-            input_size = inputs.shape[2:]
+            tensor = torch.LongTensor(2).to(inputs.device)

             if self.rank == 0:
                 size_factor = input_size[1] * 1.0 / input_size[0]
@@ -262,12 +263,12 @@ class MultiscalePrePredictionCallback(AbstractPrePredictionCallback):
                 dist.barrier()
                 dist.broadcast(tensor, 0)

-            new_input_size = (tensor[0].item(), tensor[1].item())
+            self.new_input_size = (tensor[0].item(), tensor[1].item())

-            scale_y = new_input_size[0] / input_size[0]
-            scale_x = new_input_size[1] / input_size[1]
-            if scale_x != 1 or scale_y != 1:
-                inputs = torch.nn.functional.interpolate(inputs, size=new_input_size, mode="bilinear", align_corners=False)
+        scale_y = self.new_input_size[0] / input_size[0]
+        scale_x = self.new_input_size[1] / input_size[1]
+        if scale_x != 1 or scale_y != 1:
+            inputs = torch.nn.functional.interpolate(inputs, size=self.new_input_size, mode="bilinear", align_corners=False)
         return inputs, targets


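For context, here is a minimal standalone sketch of the state-keeping pattern the hunk above introduces: the callback caches the last sampled size and rescales every batch with it, so all batches between two sampling points share one scale. The class name, the square sizes, the size range and the single-process setup are illustrative assumptions, not the library implementation, and batch indices are assumed to start at 0.

import random

import torch
import torch.nn.functional as F


class MultiscaleSketchCallback:
    def __init__(self, frequency=10, min_size=320, max_size=640, step=32):
        self.frequency = frequency
        self.sizes = list(range(min_size, max_size + 1, step))
        self.new_input_size = None  # cached between sampling points, mirroring the fix above

    def __call__(self, inputs, targets, batch_idx):
        input_size = inputs.shape[2:]
        if batch_idx % self.frequency == 0:
            # sample a new target size only every `frequency` batches
            side = random.choice(self.sizes)
            self.new_input_size = (side, side)
        # rescale every batch with the cached size; before the fix, batches that did not
        # fall on a sampling point were left at their original resolution
        if tuple(self.new_input_size) != tuple(input_size):
            inputs = F.interpolate(inputs, size=self.new_input_size, mode="bilinear", align_corners=False)
        return inputs, targets


if __name__ == "__main__":
    callback = MultiscaleSketchCallback(frequency=4)
    images = torch.rand(2, 3, 512, 512)
    for batch_idx in range(8):
        out, _ = callback(images, None, batch_idx)
        print(batch_idx, tuple(out.shape[2:]))  # the printed size changes only at batches 0 and 4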
@@ -26,6 +26,7 @@ from tests.unit_tests.forward_pass_prep_fn_test import ForwardpassPrepFNTest
 from tests.unit_tests.mask_loss_test import MaskAttentionLossTest
 from tests.unit_tests.detection_sub_sampling_test import TestDetectionDatasetSubsampling
 from tests.unit_tests.detection_sub_classing_test import TestDetectionDatasetSubclassing
+from tests.unit_tests.multi_scaling_test import MultiScaleTest


 class CoreUnitTestSuiteRunner:
@@ -74,6 +75,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(IoULossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubsampling))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubclassing))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MultiScaleTest))

     def _add_modules_to_end_to_end_tests_suite(self):
         """
import unittest

import torch

from super_gradients.training.datasets.datasets_utils import DetectionMultiscalePrePredictionCallback


class MultiScaleTest(unittest.TestCase):
    def setUp(self) -> None:
        self.size = (1024, 512)
        self.batch_size = 12
        self.change_frequency = 10
        self.multiscale_callback = DetectionMultiscalePrePredictionCallback(change_frequency=self.change_frequency)

    def _create_batch(self):
        inputs = torch.rand((self.batch_size, 3, self.size[0], self.size[1])) * 255
        targets = torch.cat([torch.tensor([[[0, 0, 10, 10, 0]]]) for _ in range(self.batch_size)], 0)
        return inputs, targets

    def test_multiscale_keep_state(self):
        """Check that the multiscale callback keeps the new size in memory between size swaps."""
        for i in range(5):
            post_multiscale_input_shapes = []
            for j in range(self.change_frequency):
                inputs, targets = self._create_batch()
                post_multiscale_input, _ = self.multiscale_callback(inputs, targets, batch_idx=i * self.change_frequency + j)
                post_multiscale_input_shapes.append(list(post_multiscale_input.shape))
            # The shape should be the same for every batch between k * self.change_frequency and (k + 1) * self.change_frequency
            self.assertListEqual(post_multiscale_input_shapes[0], post_multiscale_input_shapes[-1])


if __name__ == '__main__':
    unittest.main()
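Because MultiScaleTest is registered in CoreUnitTestSuiteRunner above and the module guards unittest.main(), the test runs both as part of the core unit test suite and standalone; assuming the repository root is on the Python path, it can be invoked directly with python -m unittest tests.unit_tests.multi_scaling_test.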