Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_prototype_models.py 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
  1. import pytest
  2. import test_models as TM
  3. import torch
  4. from common_utils import cpu_and_cuda, set_rng_seed
  5. from torchvision.prototype import models
  6. @pytest.mark.parametrize("model_fn", (models.depth.stereo.raft_stereo_base,))
  7. @pytest.mark.parametrize("model_mode", ("standard", "scripted"))
  8. @pytest.mark.parametrize("dev", cpu_and_cuda())
  9. def test_raft_stereo(model_fn, model_mode, dev):
  10. # A simple test to make sure the model can do forward pass and jit scriptable
  11. set_rng_seed(0)
  12. # Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output
  13. # get the idea from test_models.test_raft
  14. corr_pyramid = models.depth.stereo.raft_stereo.CorrPyramid1d(num_levels=2)
  15. corr_block = models.depth.stereo.raft_stereo.CorrBlock1d(num_levels=2, radius=2)
  16. model = model_fn(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev)
  17. if model_mode == "scripted":
  18. model = torch.jit.script(model)
  19. img1 = torch.rand(1, 3, 64, 64).to(dev)
  20. img2 = torch.rand(1, 3, 64, 64).to(dev)
  21. num_iters = 3
  22. preds = model(img1, img2, num_iters=num_iters)
  23. depth_pred = preds[-1]
  24. assert len(preds) == num_iters, "Number of predictions should be the same as model.num_iters"
  25. assert depth_pred.shape == torch.Size(
  26. [1, 1, 64, 64]
  27. ), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"
  28. # Test against expected file output
  29. TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
  30. @pytest.mark.parametrize("model_fn", (models.depth.stereo.crestereo_base,))
  31. @pytest.mark.parametrize("model_mode", ("standard", "scripted"))
  32. @pytest.mark.parametrize("dev", cpu_and_cuda())
  33. def test_crestereo(model_fn, model_mode, dev):
  34. set_rng_seed(0)
  35. model = model_fn().eval().to(dev)
  36. if model_mode == "scripted":
  37. model = torch.jit.script(model)
  38. img1 = torch.rand(1, 3, 64, 64).to(dev)
  39. img2 = torch.rand(1, 3, 64, 64).to(dev)
  40. iterations = 3
  41. preds = model(img1, img2, flow_init=None, num_iters=iterations)
  42. disparity_pred = preds[-1]
  43. # all the pyramid levels except the highest res make only half the number of iterations
  44. expected_iterations = (iterations // 2) * (len(model.resolutions) - 1)
  45. expected_iterations += iterations
  46. assert (
  47. len(preds) == expected_iterations
  48. ), "Number of predictions should be the number of iterations multiplied by the number of pyramid levels"
  49. assert disparity_pred.shape == torch.Size(
  50. [1, 2, 64, 64]
  51. ), f"Predicted disparity should have the same spatial shape as the input. Inputs shape {img1.shape[2:]}, Prediction shape {disparity_pred.shape[2:]}"
  52. assert all(
  53. d.shape == torch.Size([1, 2, 64, 64]) for d in preds
  54. ), "All predicted disparities are expected to have the same shape"
  55. # test a backward pass with a dummy loss as well
  56. preds = torch.stack(preds, dim=0)
  57. targets = torch.ones_like(preds, requires_grad=False)
  58. loss = torch.nn.functional.mse_loss(preds, targets)
  59. try:
  60. loss.backward()
  61. except Exception as e:
  62. assert False, f"Backward pass failed with an unexpected exception: {e.__class__.__name__} {e}"
  63. TM._assert_expected(disparity_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...