Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#869 Add DagsHub Logger to Super Gradients

Merged
Ghost merged 1 commits into Deci-AI:master from timho102003:dagshub_logger
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
  1. import unittest
  2. import torch
  3. from super_gradients.training.datasets.datasets_utils import DetectionMultiscalePrePredictionCallback
  4. class MultiScaleTest(unittest.TestCase):
  5. def setUp(self) -> None:
  6. self.size = (1024, 512)
  7. self.batch_size = 12
  8. self.change_frequency = 10
  9. self.multiscale_callback = DetectionMultiscalePrePredictionCallback(change_frequency=self.change_frequency)
  10. def _create_batch(self):
  11. inputs = torch.rand((self.batch_size, 3, self.size[0], self.size[1])) * 255
  12. targets = torch.cat([torch.tensor([[[0, 0, 10, 10, 0]]]) for _ in range(self.batch_size)], 0)
  13. return inputs, targets
  14. def test_multiscale_keep_state(self):
  15. """Check that the multiscale keeps in memory the new size to use between the size swaps"""
  16. for i in range(5):
  17. post_multiscale_input_shapes = []
  18. for j in range(self.change_frequency):
  19. inputs, targets = self._create_batch()
  20. post_multiscale_input, _ = self.multiscale_callback(inputs, targets, batch_idx=i * self.change_frequency + j)
  21. post_multiscale_input_shapes.append(list(post_multiscale_input.shape))
  22. # The shape should be the same for a given between k * self.change_frequency and (k+1)*self.change_frequency
  23. self.assertListEqual(post_multiscale_input_shapes[0], post_multiscale_input_shapes[-1])
  24. if __name__ == "__main__":
  25. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file