Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

segmentation_processing_PadShortToCropSize_test.py 3.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  1. import unittest
  2. import numpy as np
  3. from super_gradients.training.processing import SegmentationPadShortToCropSize
  4. from super_gradients.training.utils.predict.predictions import SegmentationPrediction
  5. class SegmentationPadShortToCropSizeTest(unittest.TestCase):
  6. def test_pad_normal_input(self):
  7. crop_size = (512, 512)
  8. fill_image = 0
  9. pad_transform = SegmentationPadShortToCropSize(crop_size, fill_image)
  10. # Test for images with different dimensions
  11. input_image_1 = np.zeros((640, 640))
  12. output_image_1, metadata_1 = pad_transform.preprocess_image(input_image_1)
  13. preprocess_output_size_1 = max(crop_size[0], input_image_1.shape[0]), max(crop_size[1], input_image_1.shape[1])
  14. self.assertEqual(output_image_1.shape, preprocess_output_size_1)
  15. input_image_2 = np.ones((800, 400))
  16. output_image_2, metadata_2 = pad_transform.preprocess_image(input_image_2)
  17. preprocess_output_size_2 = max(crop_size[0], input_image_2.shape[0]), max(crop_size[1], input_image_2.shape[1])
  18. self.assertEqual(output_image_2.shape, preprocess_output_size_2)
  19. # Test for crop_size smaller than the input image
  20. input_image_3 = np.zeros((320, 320))
  21. output_image_3, metadata_3 = pad_transform.preprocess_image(input_image_3)
  22. self.assertEqual(output_image_3.shape, crop_size)
  23. def test_pad_1x1_image(self):
  24. crop_size = (512, 512)
  25. fill_image = 0
  26. pad_transform = SegmentationPadShortToCropSize(crop_size, fill_image)
  27. input_image = np.ones((1, 1))
  28. output_image, metadata = pad_transform.preprocess_image(input_image)
  29. self.assertEqual(output_image.shape, crop_size)
  30. # test postprocessing
  31. prediction_obj = SegmentationPrediction(output_image, output_image.shape, output_image.shape)
  32. output_prediction = pad_transform.postprocess_predictions(prediction_obj, metadata)
  33. # Check if the output segmentation map has the correct dimensions after removing padding
  34. self.assertEqual(output_prediction.segmentation_map.shape, input_image.shape)
  35. self.assertEqual(output_prediction.segmentation_map.all(), True)
  36. def test_pad_edge_cases(self):
  37. crop_size = (512, 512)
  38. fill_image = 0
  39. pad_transform = SegmentationPadShortToCropSize(crop_size, fill_image)
  40. # Test for crop_size equal to the input image size
  41. input_image_1 = np.zeros((512, 512))
  42. output_image_1, metadata_1 = pad_transform.preprocess_image(input_image_1)
  43. self.assertEqual(output_image_1.shape, crop_size)
  44. # Test for crop_size smaller than the input image
  45. input_image_2 = np.zeros((400, 400))
  46. output_image_2, metadata_2 = pad_transform.preprocess_image(input_image_2)
  47. self.assertEqual(output_image_2.shape, crop_size)
  48. def test_postprocess_predictions(self):
  49. crop_size = (512, 512)
  50. fill_image = 0
  51. pad_transform = SegmentationPadShortToCropSize(crop_size, fill_image)
  52. # Create a segmentation prediction object with a known segmentation map shape
  53. input_image_shape = (400, 400)
  54. input_image = np.ones(input_image_shape)
  55. segmentation_map, metadata = pad_transform.preprocess_image(input_image)
  56. prediction_obj = SegmentationPrediction(segmentation_map, crop_size, crop_size)
  57. # Apply the postprocess_predictions method with known padding_coordinates
  58. output_prediction = pad_transform.postprocess_predictions(prediction_obj, metadata)
  59. # Check if the output segmentation map has the correct dimensions after removing padding
  60. self.assertEqual(output_prediction.segmentation_map.shape, input_image_shape)
  61. self.assertEqual(output_prediction.segmentation_map.all(), True)
  62. if __name__ == "__main__":
  63. unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...