forward_with_sliding_window_test.py

import unittest

import torch
import torch.nn

from super_gradients.training.utils.segmentation_utils import forward_with_sliding_window_wrapper


class SlidingWindowTest(unittest.TestCase):
    """Tests for forward_with_sliding_window_wrapper.

    With an identity model, sliding-window inference should stitch the crops back into a tensor
    identical to the input; invalid stride/crop combinations should raise ValueError.
    """

    def setUp(self) -> None:
        self.num_classes = 1

    def _assert_tensors_equal(self, tensor1, tensor2):
        self.assertTrue(torch.allclose(tensor1, tensor2, atol=1e-6))

    def test_input_smaller_than_crop_size_and_crop_size_equal_stride_size(self):
        input_size = (512, 512)
        crop_size = (640, 640)
        stride_size = (640, 640)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        reconstructed_input = forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)
        self._assert_tensors_equal(input_tensor, reconstructed_input)

    def test_input_smaller_than_crop_size_and_stride_size_larger_than_crop_size(self):
        input_size = (512, 512)
        crop_size = (640, 640)
        stride_size = (768, 768)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        # A stride larger than the crop would leave gaps between windows, so the wrapper must reject it.
        with self.assertRaises(ValueError):
            forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)

    def test_input_smaller_than_crop_size_and_stride_size_less_than_crop_size(self):
        input_size = (512, 512)
        crop_size = (640, 640)
        stride_size = (384, 384)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        reconstructed_input = forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)
        self._assert_tensors_equal(input_tensor, reconstructed_input)

    def test_input_equal_to_crop_size_and_crop_size_equal_stride_size(self):
        input_size = (512, 512)
        crop_size = (512, 512)
        stride_size = (512, 512)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        reconstructed_input = forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)
        self._assert_tensors_equal(input_tensor, reconstructed_input)

    def test_input_equal_to_crop_size_and_stride_size_larger_than_crop_size(self):
        input_size = (512, 512)
        crop_size = (512, 512)
        stride_size = (640, 640)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        with self.assertRaises(ValueError):
            forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)

    def test_input_equal_to_crop_size_and_stride_size_less_than_crop_size(self):
        input_size = (512, 512)
        crop_size = (512, 512)
        stride_size = (384, 384)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        reconstructed_input = forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)
        self._assert_tensors_equal(input_tensor, reconstructed_input)

    def test_input_larger_than_crop_size_and_crop_size_equal_stride_size(self):
        input_size = (513, 513)
        crop_size = (512, 512)
        stride_size = (512, 512)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        reconstructed_input = forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)
        self._assert_tensors_equal(input_tensor, reconstructed_input)

    def test_input_larger_than_crop_size_and_stride_size_larger_than_crop_size(self):
        input_size = (513, 513)
        crop_size = (512, 512)
        stride_size = (640, 640)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        with self.assertRaises(ValueError):
            forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)

    def test_input_larger_than_crop_size_and_stride_size_less_than_crop_size(self):
        input_size = (513, 513)
        crop_size = (512, 512)
        stride_size = (384, 384)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        reconstructed_input = forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)
        self._assert_tensors_equal(input_tensor, reconstructed_input)

    def test_odd_sized_input(self):
        # Small, odd-sized input with overlapping windows; the stitched output must still match the input exactly.
        input_size = (13, 13)
        crop_size = (3, 3)
        stride_size = (2, 2)
        model = DummyModel()
        input_tensor = torch.randn((1, 1) + input_size)
        reconstructed_input = forward_with_sliding_window_wrapper(model.forward, input_tensor, stride_size, crop_size, self.num_classes)
        self._assert_tensors_equal(input_tensor, reconstructed_input)


class DummyModel(torch.nn.Module):
    """Identity model: returns its input unchanged, so any stitching error shows up as a tensor mismatch."""

    def forward(self, x):
        return x


if __name__ == "__main__":
    unittest.main()
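For context, here is a minimal usage sketch of the wrapper being tested, outside the test suite. The argument order (forward function, input tensor, stride, crop size, number of classes) follows the calls in the tests above; everything else, including the toy 1x1-conv "model", the channel counts, and the image and window sizes, is an illustrative assumption rather than SuperGradients' documented usage.

# Hypothetical usage sketch (not part of the test file): sliding-window inference
# over a large image with a toy per-pixel classifier standing in for a real
# segmentation network. Sizes and class count are assumptions for illustration.
import torch
import torch.nn

from super_gradients.training.utils.segmentation_utils import forward_with_sliding_window_wrapper

num_classes = 4  # assumed number of segmentation classes
toy_model = torch.nn.Conv2d(in_channels=3, out_channels=num_classes, kernel_size=1)

large_image = torch.randn(1, 3, 1024, 1536)  # (N, C, H, W) full-resolution input
logits = forward_with_sliding_window_wrapper(
    toy_model.forward,  # forward function applied to each crop
    large_image,        # input tensor
    (256, 256),         # sliding-window stride; stride <= crop, so windows overlap
    (512, 512),         # crop size fed to the model
    num_classes,        # number of output channels expected per crop
)
prediction = logits.argmax(dim=1)  # per-pixel class map at full resolution

Because the tests only exercise an identity model with a single class, treat the overlapping-window, multi-class behaviour above as the typical sliding-window inference pattern rather than a guarantee of this specific API; the file itself can be run directly with python -m unittest forward_with_sliding_window_test.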