Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

detection_caching.py 4.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. import unittest
  2. import numpy as np
  3. from pathlib import Path
  4. import tempfile
  5. import os
  6. from super_gradients.training.datasets import DetectionDataset
  7. from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
  class DummyDetectionDataset(DetectionDataset):
      """Synthetic ``DetectionDataset`` subclass used to test subclassing and caching.

      No file is read from disk: every sample gets a single hard-coded box and a random image
      that is generated deterministically from the sample id.

      NOTE(review): the original docstring said the dataset was designed with no annotation
      that includes class_2, yet ``_load_annotation`` assigns ``sample_id % len(all_classes_list)``,
      which reaches index 2 when the list has three entries -- confirm the intent.
      """

      def __init__(self, input_dim, *args, **kwargs):
          """Force the class list and target format, then defer to the parent constructor.

          :param input_dim: Image size; stored as ``self.image_size`` and forwarded to the parent.
                            Indexed as ``[0]`` and ``[1]`` when generating images.
          :param args:      Extra positional arguments forwarded to ``DetectionDataset``.
          :param kwargs:    Extra keyword arguments forwarded to ``DetectionDataset``;
                            ``all_classes_list`` and ``original_target_format`` are overwritten.
          """
          # These attributes are set before calling the parent constructor
          # (presumably because it invokes the hooks below -- TODO confirm).
          self.image_size = input_dim
          self.n_samples = 321

          kwargs["all_classes_list"] = ["class_0", "class_1", "class_2"]
          kwargs["original_target_format"] = XYXY_LABEL
          super().__init__(*args, input_dim=input_dim, **kwargs)

      def _setup_data_source(self):
          """Return the number of samples in the dataset."""
          return self.n_samples

      def _load_annotation(self, sample_id: int) -> dict:
          """Build the annotation of one sample.

          Every image is made of one target, with label ``sample_id % len(all_classes_list)``, and
          a seed that allows the random image to be the same for a given ``sample_id``.

          :param sample_id: Index of the sample.
          :return:          Dict with the keys "img_path", "target", "resized_img_shape" and "seed".
          """
          cls_id = sample_id % len(self.all_classes_list)
          return {
              "img_path": str(sample_id),  # The "path" is the sample id; _load_image parses it back to an int.
              "target": np.array([[0, 0, 10, 10, cls_id]]),
              "resized_img_shape": self.image_size,
              "seed": sample_id,
          }

      # We overwrite this to fake images
      def _load_image(self, image_path: str) -> np.ndarray:
          """Generate the fake image of a sample.

          :param image_path: Stringified sample id, used as the random seed.
          :return:           Array of shape (image_size[0], image_size[1], 3) with values in [0, 255).
          """
          # A private RandomState yields the same values as seeding the global generator and calling
          # np.random.random, but it does not mutate process-wide RNG state and stays deterministic
          # even if other code draws from the global generator in between.
          rng = np.random.RandomState(int(image_path))
          return rng.random_sample((self.image_size[0], self.image_size[1], 3)) * 255
  class TestDetectionDatasetCaching(unittest.TestCase):
      """Checks how many ``*.array`` cache files dataset construction leaves in the cache directory."""

      def setUp(self) -> None:
          """Create a fresh cache directory for the test and schedule its removal."""
          # Keep the TemporaryDirectory object alive: when it is dropped right away its finalizer
          # deletes the directory, and a manually re-created directory is never cleaned up.
          temp_dir = tempfile.TemporaryDirectory(prefix="cache")
          self.addCleanup(temp_dir.cleanup)
          self.temp_cache_dir = temp_dir.name

      def _count_cached_array(self):
          """Return the number of ``*.array`` files directly inside the cache directory."""
          return len(list(Path(self.temp_cache_dir).glob("*.array")))

      def _empty_cache(self):
          """Delete every ``*.array`` file directly inside the cache directory."""
          for cache_file in Path(self.temp_cache_dir).glob("*.array"):
              cache_file.unlink()

      def _build_datasets(self, ignore_empty_annotations: bool) -> list:
          """Build one cached dummy dataset per class filter used by the cache tests.

          :param ignore_empty_annotations: Forwarded as-is to every dataset.
          :return:                         The datasets, in the order of the class filters.
          """
          class_inclusion_lists = [["class_0", "class_1", "class_2"], ["class_0"], ["class_1"], ["class_2"], ["class_1", "class_2"]]
          return [
              DummyDetectionDataset(
                  input_dim=(640, 512),
                  ignore_empty_annotations=ignore_empty_annotations,
                  class_inclusion_list=class_inclusion_list,
                  cache=True,
                  cache_dir=self.temp_cache_dir,
                  data_dir="/home/",
              )
              for class_inclusion_list in class_inclusion_lists
          ]

      def test_cache_keep_empty(self):
          """With ``ignore_empty_annotations=False``, expect one cache file and identical cached images."""
          self._empty_cache()

          datasets = self._build_datasets(ignore_empty_annotations=False)

          self.assertEqual(1, self._count_cached_array())
          for first_dataset, second_dataset in zip(datasets[:-1], datasets[1:]):
              self.assertTrue(np.array_equal(first_dataset.cached_imgs_padded, second_dataset.cached_imgs_padded))

          self._empty_cache()

      def test_cache_ignore_empty(self):
          """With ``ignore_empty_annotations=True``, expect one cache file per filter and differing images."""
          self._empty_cache()

          datasets = self._build_datasets(ignore_empty_annotations=True)

          self.assertEqual(5, self._count_cached_array())
          for first_dataset, second_dataset in zip(datasets[:-1], datasets[1:]):
              self.assertFalse(np.array_equal(first_dataset.cached_imgs_padded, second_dataset.cached_imgs_padded))

          self._empty_cache()

      def test_cache_saved(self):
          """Check that after the first time a dataset is called with specific params,
          the next time it will call the saved array instead of building it."""
          self._empty_cache()
          self.assertEqual(0, self._count_cached_array())

          _ = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, cache=True, cache_dir=self.temp_cache_dir, data_dir="/home/")
          self.assertEqual(1, self._count_cached_array())

          # Re-creating the dataset with identical parameters must not add cache files.
          for _ in range(5):
              _ = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, cache=True, cache_dir=self.temp_cache_dir, data_dir="/home/")
              self.assertEqual(1, self._count_cached_array())

          self._empty_cache()
  if __name__ == "__main__":
      # Run every test case in this module when the file is executed directly as a script.
      unittest.main()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...