detection_dataset_test.py

import tempfile
import os
import unittest
from typing import Dict, Union, Any

import numpy as np
import pkg_resources
from hydra import initialize_config_dir, compose
from hydra.core.global_hydra import GlobalHydra
from pydantic.main import deepcopy

import super_gradients
from super_gradients.training.dataloaders.dataloaders import _process_dataset_params
from super_gradients.training.datasets import PascalVOCDetectionDataset, COCODetectionDataset
from super_gradients.training.transforms import DetectionMosaic, DetectionPaddedRescale, DetectionTargetsFormatTransform
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException
from super_gradients.training.utils.hydra_utils import normalize_path
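

# COCO detection dataset variant that stacks the resized RGB image with itself along
# the channel axis, producing a 6-channel sample. It exists only to exercise the
# transform pipeline with a non-standard channel count in the test further below.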
class COCODetectionDataset6Channels(COCODetectionDataset):
    def get_sample(self, index: int) -> Dict[str, Union[np.ndarray, Any]]:
        img = self.get_resized_image(index)
        img = np.concatenate((img, img), 2)
        annotation = deepcopy(self.annotations[index])
        return {"image": img, **annotation}
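

# Integration tests covering class-inclusion filtering, sub-sampling, sample plotting,
# and transform handling for the Pascal VOC and COCO detection datasets.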
class DatasetIntegrationTest(unittest.TestCase):
    def setUp(self) -> None:
        super_gradients.init_trainer()
        self.batch_size = 64
        self.max_samples_per_plot = 16
        self.n_plot = 1
        transforms = [
            DetectionMosaic(input_dim=(640, 640), prob=0.8),
            DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
            DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.XYXY_LABEL),
        ]

        self.test_dir = tempfile.TemporaryDirectory().name
        PascalVOCDetectionDataset.download(self.test_dir)

        self.pascal_class_inclusion_lists = [["aeroplane", "bicycle"], ["bird", "boat", "bottle", "bus"], ["pottedplant"], ["person"]]
        self.pascal_base_config = dict(data_dir=self.test_dir, images_sub_directory="images/train2012/", input_dim=(640, 640), transforms=transforms)

        self.coco_class_inclusion_lists = [["airplane", "bicycle"], ["bird", "boat", "bottle", "bus"], ["potted plant"], ["person"]]
        self.dataset_coco_base_config = dict(
            data_dir="/data/coco",
            subdir="images/val2017",
            json_file="instances_val2017.json",
            input_dim=(640, 640),
            transforms=transforms,
        )
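
    # Plotting tests: instantiate each dataset with a class-inclusion list and plot a
    # handful of samples, once on the raw data and once with the transforms applied.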
    def test_multiple_pascal_dataset_subclass_before_transforms(self):
        """Plot raw (pre-transform) samples for each Pascal VOC class inclusion list."""
        for class_inclusion_list in self.pascal_class_inclusion_lists:
            dataset = PascalVOCDetectionDataset(
                class_inclusion_list=class_inclusion_list, max_num_samples=self.max_samples_per_plot * self.n_plot, **self.pascal_base_config
            )
            dataset.plot(max_samples_per_plot=self.max_samples_per_plot, n_plots=self.n_plot, plot_transformed_data=False)

    def test_multiple_pascal_dataset_subclass_after_transforms(self):
        """Plot transformed samples for each Pascal VOC class inclusion list."""
        for class_inclusion_list in self.pascal_class_inclusion_lists:
            dataset = PascalVOCDetectionDataset(
                class_inclusion_list=class_inclusion_list, max_num_samples=self.max_samples_per_plot * self.n_plot, **self.pascal_base_config
            )
            dataset.plot(max_samples_per_plot=self.max_samples_per_plot, n_plots=self.n_plot, plot_transformed_data=True)

    def test_multiple_coco_dataset_subclass_before_transforms(self):
        """Check the COCO subclass on multiple inclusion lists before transforms."""
        for class_inclusion_list in self.coco_class_inclusion_lists:
            dataset = COCODetectionDataset(
                class_inclusion_list=class_inclusion_list, max_num_samples=self.max_samples_per_plot * self.n_plot, **self.dataset_coco_base_config
            )
            dataset.plot(max_samples_per_plot=self.max_samples_per_plot, n_plots=self.n_plot, plot_transformed_data=False)

    def test_multiple_coco_dataset_subclass_after_transforms(self):
        """Check the COCO subclass on multiple inclusion lists after transforms."""
        for class_inclusion_list in self.coco_class_inclusion_lists:
            dataset = COCODetectionDataset(
                class_inclusion_list=class_inclusion_list, max_num_samples=self.max_samples_per_plot * self.n_plot, **self.dataset_coco_base_config
            )
            dataset.plot(max_samples_per_plot=self.max_samples_per_plot, n_plots=self.n_plot, plot_transformed_data=True)
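
    # Filtering and sub-sampling behaviour: unknown class names should fail fast, and
    # max_num_samples should cap the dataset length (raising once it yields zero samples).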
    def test_subclass_non_existing_class(self):
        """Check that ValueError is raised when class_inclusion_list contains an unknown label."""
        with self.assertRaises(ValueError):
            PascalVOCDetectionDataset(class_inclusion_list=["new_class"], **self.pascal_base_config)

    def test_sub_sampling_dataset(self):
        """Check that sub-sampling works."""
        full_dataset = PascalVOCDetectionDataset(**self.pascal_base_config)

        with self.assertRaises(EmptyDatasetException):
            PascalVOCDetectionDataset(max_num_samples=0, **self.pascal_base_config)

        for max_num_samples in [1, 10, 1000, 1_000_000]:
            sampled_dataset = PascalVOCDetectionDataset(max_num_samples=max_num_samples, **self.pascal_base_config)
            self.assertEqual(len(sampled_dataset), min(max_num_samples, len(full_dataset)))
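
    # Compose the packaged COCO detection recipe with Hydra, reuse its transform pipeline
    # on the 6-channel dataset variant, and verify the channel count survives the transforms.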
    def test_detection_dataset_transforms_with_unique_channel_count(self):
        GlobalHydra.instance().clear()
        sg_recipes_dir = pkg_resources.resource_filename("super_gradients.recipes", "")
        dataset_config = os.path.join("dataset_params", "coco_detection_dataset_params")
        with initialize_config_dir(config_dir=normalize_path(sg_recipes_dir), version_base="1.2"):
            # config is relative to a module
            cfg = compose(config_name=normalize_path(dataset_config))
            dataset_params = _process_dataset_params(cfg, dict(), True)

        coco_base_recipe_transforms = dataset_params["transforms"]
        dataset_config = deepcopy(self.dataset_coco_base_config)
        dataset_config["transforms"] = coco_base_recipe_transforms
        dataset = COCODetectionDataset6Channels(**dataset_config)
        self.assertEqual(dataset.__getitem__(0)[0].shape[0], 6)


if __name__ == "__main__":
    unittest.main()
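
A minimal usage sketch (not part of the file above): running a single test from this module with the standard unittest loader, assuming the file is importable as detection_dataset_test and that super_gradients is installed. setUp still downloads the Pascal VOC data into a temporary directory, but no local COCO copy at /data/coco is needed for this particular test.

import unittest

# Load only the sub-sampling test; the COCO-dependent tests are never constructed for this run.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "detection_dataset_test.DatasetIntegrationTest.test_sub_sampling_dataset"
)
unittest.TextTestRunner(verbosity=2).run(suite)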