Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#869 Add DagsHub Logger to Super Gradients

Merged
Ghost merged 1 commits into Deci-AI:master from timho102003:dagshub_logger
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
  1. import unittest
  2. import numpy as np
  3. import torch
  4. from super_gradients.training.losses.dekr_loss import DEKRLoss
  5. from super_gradients.training.datasets.pose_estimation_datasets.target_generators import DEKRTargetsGenerator
  6. class DEKRLossTest(unittest.TestCase):
  7. def test_dekr_loss(self):
  8. num_joints = 17
  9. num_persons = 3
  10. target_generator = DEKRTargetsGenerator(output_stride=4, sigma=2, center_sigma=4, bg_weight=0.1, offset_radius=4)
  11. joints = np.random.randint(1, 255, (num_persons, num_joints, 3))
  12. image = torch.randn((3, 256, 256))
  13. mask = np.ones((256, 256))
  14. joints[:, :, 2] = 1 # All visible
  15. targets = target_generator(image, joints, mask)
  16. gt_heatmaps, gt_mask, gt_offset_map, gt_offset_weight = targets
  17. self.assertEqual(
  18. gt_heatmaps.shape, (num_joints + 1, image.shape[1] // target_generator.output_stride, image.shape[2] // target_generator.output_stride)
  19. )
  20. random_predictions = torch.randn(
  21. (1, num_joints + 1, image.shape[1] // target_generator.output_stride, image.shape[2] // target_generator.output_stride)
  22. ), torch.randn((1, num_joints * 2, image.shape[1] // target_generator.output_stride, image.shape[2] // target_generator.output_stride))
  23. targets = (
  24. torch.from_numpy(gt_heatmaps).unsqueeze(0),
  25. torch.from_numpy(gt_mask).unsqueeze(0),
  26. torch.from_numpy(gt_offset_map).unsqueeze(0),
  27. torch.from_numpy(gt_offset_weight).unsqueeze(0),
  28. )
  29. loss = DEKRLoss()
  30. main_loss, loss_components = loss(random_predictions, targets)
  31. self.assertEqual(len(loss_components), len(loss.component_names))
  32. perfect_predictions = targets[0], targets[2]
  33. main_loss, loss_components = loss(perfect_predictions, targets)
  34. self.assertEqual(main_loss.item(), 0)
  35. if __name__ == "__main__":
  36. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file