pose_estimation_dataset_test.py

```python
import unittest

import numpy as np
import torch

from super_gradients.training.datasets.pose_estimation_datasets import DEKRTargetsGenerator


class TestPoseEstimationDataset(unittest.TestCase):
    def test_dekr_target_generator(self):
        target_generator = DEKRTargetsGenerator(
            output_stride=4,
            sigma=2,
            center_sigma=4,
            bg_weight=0.1,
            offset_radius=4,
        )

        # 4 random instances with 17 joints each; the last axis holds (x, y, visibility)
        joints = np.random.randint(0, 255, (4, 17, 3))
        joints[:, :, 2] = 1  # mark every joint as visible

        heatmaps, mask, offset_map, offset_weight = target_generator(
            image=torch.zeros((3, 256, 256)),
            joints=joints,
            mask=np.ones((256, 256)),
        )

        # With output_stride=4, a 256x256 input yields 64x64 target maps.
        # Heatmaps and mask: 17 joint channels + 1 center channel = 18.
        self.assertEqual(heatmaps.shape, (18, 64, 64))
        self.assertEqual(mask.shape, (18, 64, 64))
        # Offsets: 2 channels (x and y) per joint = 34.
        self.assertEqual(offset_map.shape, (34, 64, 64))
        self.assertEqual(offset_weight.shape, (34, 64, 64))


if __name__ == "__main__":
    unittest.main()
```
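The shape assertions in this test follow from simple arithmetic on the input size, the output stride and the joint count. The following is a minimal sketch of that arithmetic using a hypothetical helper, `expected_dekr_target_shapes`, which is not part of super_gradients. The channel layout it assumes (one extra heatmap channel for the DEKR instance center, two offset channels per joint) is inferred from the asserted shapes rather than taken from the library, and the sketch only handles image sizes divisible by the stride.

```python
from typing import Dict, Tuple


def expected_dekr_target_shapes(
    image_size: Tuple[int, int],
    num_joints: int,
    output_stride: int,
) -> Dict[str, Tuple[int, int, int]]:
    """Hypothetical helper: target shapes inferred from the test's assertions."""
    height, width = image_size
    # Assumption of this sketch: only sizes divisible by the stride are handled.
    if height % output_stride or width % output_stride:
        raise ValueError("image size must be divisible by output_stride")
    out_h, out_w = height // output_stride, width // output_stride

    # One heatmap per joint plus one extra channel (assumed: the instance-center heatmap).
    heatmap_channels = num_joints + 1
    # Two offset channels (x and y) per joint.
    offset_channels = 2 * num_joints

    return {
        "heatmaps": (heatmap_channels, out_h, out_w),
        "mask": (heatmap_channels, out_h, out_w),
        "offset_map": (offset_channels, out_h, out_w),
        "offset_weight": (offset_channels, out_h, out_w),
    }


shapes = expected_dekr_target_shapes((256, 256), num_joints=17, output_stride=4)
print(shapes["heatmaps"])    # (18, 64, 64)
print(shapes["offset_map"])  # (34, 64, 64)
```

With the test's settings (256x256 input, 17 joints, stride 4) the helper reproduces the four asserted shapes, which makes it easy to see how the expectations would change for a different stride or skeleton.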