Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#309 Fix scale between rescaling batches

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-221-make_multiscale_keep_state
@@ -26,6 +26,7 @@ from tests.unit_tests.forward_pass_prep_fn_test import ForwardpassPrepFNTest
 from tests.unit_tests.mask_loss_test import MaskAttentionLossTest
 from tests.unit_tests.mask_loss_test import MaskAttentionLossTest
 from tests.unit_tests.detection_sub_sampling_test import TestDetectionDatasetSubsampling
 from tests.unit_tests.detection_sub_sampling_test import TestDetectionDatasetSubsampling
 from tests.unit_tests.detection_sub_classing_test import TestDetectionDatasetSubclassing
 from tests.unit_tests.detection_sub_classing_test import TestDetectionDatasetSubclassing
+from tests.unit_tests.multi_scaling_test import MultiScaleTest
 
 
 
 
 class CoreUnitTestSuiteRunner:
 class CoreUnitTestSuiteRunner:
@@ -74,6 +75,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(IoULossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(IoULossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubsampling))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubsampling))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubclassing))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubclassing))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MultiScaleTest))
 
 
     def _add_modules_to_end_to_end_tests_suite(self):
     def _add_modules_to_end_to_end_tests_suite(self):
         """
         """
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
  1. import unittest
  2. import torch
  3. from super_gradients.training.datasets.datasets_utils import DetectionMultiscalePrePredictionCallback
  4. class MultiScaleTest(unittest.TestCase):
  5. def setUp(self) -> None:
  6. self.size = (1024, 512)
  7. self.batch_size = 12
  8. self.change_frequency = 10
  9. self.multiscale_callback = DetectionMultiscalePrePredictionCallback(change_frequency=self.change_frequency)
  10. def _create_batch(self):
  11. inputs = torch.rand((self.batch_size, 3, self.size[0], self.size[1])) * 255
  12. targets = torch.cat([torch.tensor([[[0, 0, 10, 10, 0]]]) for _ in range(self.batch_size)], 0)
  13. return inputs, targets
  14. def test_multiscale_keep_state(self):
  15. """Check that the multiscale keeps in memory the new size to use between the size swaps"""
  16. for i in range(5):
  17. post_multiscale_input_shapes = []
  18. for j in range(self.change_frequency):
  19. inputs, targets = self._create_batch()
  20. post_multiscale_input, _ = self.multiscale_callback(inputs, targets, batch_idx=i * self.change_frequency + j)
  21. post_multiscale_input_shapes.append(list(post_multiscale_input.shape))
  22. # The shape should be the same for a given between k * self.change_frequency and (k+1)*self.change_frequency
  23. self.assertListEqual(post_multiscale_input_shapes[0], post_multiscale_input_shapes[-1])
  24. if __name__ == '__main__':
  25. unittest.main()
Discard