Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_yolo_nas_pose.py 4.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
  1. import unittest
  2. import torch
  3. from super_gradients.common.object_names import Models
  4. from super_gradients.training import models
  5. from super_gradients.training.datasets.pose_estimation_datasets.yolo_nas_pose_collate_fn import (
  6. flat_collate_tensors_with_batch_index,
  7. undo_flat_collate_tensors_with_batch_index,
  8. )
  9. from super_gradients.training.losses import YoloNASPoseLoss
  10. class YoloNASPoseTests(unittest.TestCase):
  11. def test_yolo_nas_pose_forward(self):
  12. num_joints = 33
  13. model = models.get(Models.YOLO_NAS_POSE_N, num_classes=num_joints).eval()
  14. input = torch.randn((1, 3, 640, 640))
  15. decoded_predictions, _ = model(input)
  16. pred_bboxes, pred_scores, pred_pose_coords, pred_pose_scores = decoded_predictions
  17. self.assertEquals(pred_bboxes.shape[2], 4)
  18. self.assertEquals(pred_scores.shape[2], 1)
  19. self.assertEquals(pred_pose_coords.shape[2], num_joints)
  20. self.assertEquals(pred_pose_coords.shape[3], 2)
  21. self.assertEquals(pred_pose_scores.shape[2], num_joints)
  22. def test_yolo_nas_pose_loss_function(self):
  23. model = models.get(Models.YOLO_NAS_POSE_N, num_classes=17)
  24. input = torch.randn((3, 3, 640, 640))
  25. outputs = model(input)
  26. criterion = YoloNASPoseLoss(
  27. oks_sigmas=[0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089],
  28. )
  29. # A single tensor of shape (N, 1 + 4 + Num Joints * 3) (batch_index, x1, y1, x2, y2, [x, y, visibility] * Num Joints)
  30. # First image has 1 object, second image has 2 objects, third image has no objects
  31. target_boxes = flat_collate_tensors_with_batch_index(
  32. [
  33. torch.tensor([[10, 10, 100, 200]]),
  34. torch.tensor([[300, 500, 400, 550], [200, 200, 400, 400]]),
  35. torch.zeros((0, 4)),
  36. ]
  37. ).float()
  38. target_poses = flat_collate_tensors_with_batch_index(
  39. [
  40. torch.randn((1, 17, 3)), # First image has 1 object
  41. torch.randn((2, 17, 3)), # Second image has 2 objects
  42. torch.zeros((0, 17, 3)), # Third image has no objects
  43. ]
  44. ).float()
  45. target_poses[..., 3] = 2.0 # Mark all joints as visible
  46. target_crowds = flat_collate_tensors_with_batch_index([torch.zeros((1, 1)), torch.zeros((2, 1)), torch.zeros((0, 1))]).float()
  47. targets = (target_boxes, target_poses, target_crowds)
  48. loss = criterion(outputs=outputs, targets=targets)
  49. loss[0].backward()
  50. def test_flat_collate_2d(self):
  51. values = [
  52. torch.randn([1, 4]),
  53. torch.randn([2, 4]),
  54. torch.randn([0, 4]),
  55. torch.randn([3, 4]),
  56. ]
  57. flat_tensor = flat_collate_tensors_with_batch_index(values)
  58. undo_values = undo_flat_collate_tensors_with_batch_index(flat_tensor, 4)
  59. assert len(undo_values) == len(values)
  60. assert (undo_values[0] == values[0]).all()
  61. assert (undo_values[1] == values[1]).all()
  62. assert (undo_values[2] == values[2]).all()
  63. assert (undo_values[3] == values[3]).all()
  64. def test_flat_collate_3d(self):
  65. values = [
  66. torch.randn([1, 17, 3]),
  67. torch.randn([2, 17, 3]),
  68. torch.randn([0, 17, 3]),
  69. torch.randn([3, 17, 3]),
  70. ]
  71. flat_tensor = flat_collate_tensors_with_batch_index(values)
  72. undo_values = undo_flat_collate_tensors_with_batch_index(flat_tensor, 4)
  73. assert len(undo_values) == len(values)
  74. assert (undo_values[0] == values[0]).all()
  75. assert (undo_values[1] == values[1]).all()
  76. assert (undo_values[2] == values[2]).all()
  77. assert (undo_values[3] == values[3]).all()
  78. def test_yolo_nas_pose_replace_classes(self):
  79. model = models.get(Models.YOLO_NAS_POSE_N, num_classes=17)
  80. model.replace_head(new_num_classes=20)
  81. input = torch.randn((1, 3, 640, 640))
  82. decoded_predictions, _ = model(input)
  83. pred_bboxes, pred_scores, pred_pose_coords, pred_pose_scores = decoded_predictions
  84. self.assertEqual(pred_pose_coords.shape[2], 20)
  85. self.assertEqual(pred_pose_scores.shape[2], 20)
  86. if __name__ == "__main__":
  87. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...