Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

average_meter_test.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
  1. import torch
  2. import unittest
  3. from super_gradients.training.utils.utils import AverageMeter
  4. class TestAverageMeter(unittest.TestCase):
  5. """Test the behavior of the class is not changed since several parts of the code rely on it"""
  6. @classmethod
  7. def setUp(cls):
  8. cls.avg_float = AverageMeter()
  9. cls.avg_tuple = AverageMeter()
  10. cls.avg_list = AverageMeter()
  11. cls.avg_tensor = AverageMeter()
  12. cls.left_empty = AverageMeter()
  13. cls.list_of_avg_meter = [cls.avg_float, cls.avg_tuple, cls.avg_list, cls.avg_tensor]
  14. cls.score_types = [1.2, (3., 4.), [5., 6., 7.], torch.FloatTensor([8., 9., 10.])]
  15. cls.batch_size = 3
  16. def test_empty_return_0(self):
  17. self.assertEqual(self.left_empty.average, 0)
  18. def test_correctness_and_typing(self):
  19. # VERIFY THE VALUES ARE INITIALIZED TO None & 0 FOR THE ABILITY TO USE ANY TYPE OF value
  20. self.assertIsNone(self.avg_float._sum)
  21. self.assertEqual(self.avg_float._count, 0)
  22. # RUN OVER THE DIFFERENT TYPES OF avg_meters AND VERIFY FOR EACH
  23. for list_idx, (avg_meter, score) in enumerate(zip(self.list_of_avg_meter, self.score_types)):
  24. # ADD THE VALUES 3 TIMES
  25. for repetition in [1, 2, 3]:
  26. avg_meter.update(score, self.batch_size)
  27. # VERIFY VALUES ARE CORRECT AND OF THE EXPECTED TYPES
  28. self.assertEqual(avg_meter._count, self.batch_size * repetition)
  29. if list_idx == 0: # FOR avg_float
  30. self.assertEqual(avg_meter._count, self.batch_size * repetition)
  31. self.assertIsInstance(avg_meter.average, float)
  32. self.assertAlmostEqual(avg_meter.average, score)
  33. else: # FOR ALL THE OTHERS
  34. self.assertListEqual(list(avg_meter._sum), [val * self.batch_size * repetition for val in score])
  35. self.assertIsInstance(avg_meter.average, tuple)
  36. self.assertListEqual(list(avg_meter.average), list(score))
  37. if __name__ == '__main__':
  38. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...