Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_poe_se3_criterion.py 1.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
  1. import os
  2. import unittest
  3. import torch
  4. from lieposenet.criterions import POESE3Criterion
  5. class TestPOESE3Criterion(unittest.TestCase):
  6. def setUp(self) -> None:
  7. torch.autograd.set_detect_anomaly(True)
  8. self._criterion = POESE3Criterion()
  9. self._position = (torch.arange(1, self._criterion.position_dimension + 1, requires_grad=True,
  10. dtype=torch.float, device="cuda:0") * 0.3)[None]
  11. self._position = torch.repeat_interleave(self._position, 2, dim=0)
  12. self._position[:6] = 0
  13. self._target_position = torch.eye(4, 4, device="cuda:0")[None]
  14. self._target_position = torch.repeat_interleave(self._target_position, 2, dim=0)
  15. def test_forward(self):
  16. loss = self._criterion.forward(self._position, self._target_position)
  17. self.assertIsNotNone(loss)
  18. loss.backward()
  19. def test_translation(self):
  20. translation = self._criterion.translation(self._position)
  21. self.assertEqual(translation.shape, torch.Size([2, 3]))
  22. def test_rotation(self):
  23. rotation = self._criterion.rotation(self._position)
  24. self.assertEqual(rotation.shape, torch.Size([2, 4]))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...