#278 Adding new version of detection dataset, PascalVOC and PascalVOC dataset interface

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-171-DetectionDatasetV2_with_PascalVOC
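In short, this PR adds a generic DetectionDataset base class, a PascalVOCDetectionDataset built on it, and a PascalVOCUnifiedDetectionDatasetInterface that trains on the union of the VOC 2007+2012 trainval splits and validates on test2007. A minimal usage sketch (paths and transform settings are placeholders, not taken from this PR):

# Hypothetical usage of the classes added in this PR; data_dir is a placeholder.
from super_gradients.training.datasets import PascalVOCDetectionDataset
from super_gradients.training.transforms.transforms import DetectionPaddedRescale

# One-time download + conversion of VOC2007/VOC2012 into the layout the dataset expects
PascalVOCDetectionDataset.download(data_dir="./data/pascal_voc/")

trainset = PascalVOCDetectionDataset(data_dir="./data/pascal_voc/",
                                     images_sub_directory="images/train2012/",
                                     input_dim=(640, 640),
                                     transforms=[DetectionPaddedRescale(input_dim=(640, 640))])
image, target = trainset[0]  # __getitem__ returns the fields listed in output_fields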
@@ -4,6 +4,7 @@ from super_gradients.training.datasets.data_augmentation import DataAugmentation
 from super_gradients.training.datasets.sg_dataset import ListDataset, DirectoryDataSet
 from super_gradients.training.datasets.all_datasets import CLASSIFICATION_DATASETS, OBJECT_DETECTION_DATASETS, \
     SEMANTIC_SEGMENTATION_DATASETS
+from super_gradients.training.datasets.detection_datasets import DetectionDataset, COCODetectionDataset, PascalVOCDetectionDataset
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation import PascalVOC2012SegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation import PascalAUG2012SegmentationDataSet
@@ -22,4 +23,5 @@ __all__ = ['DataAugmentation', 'ListDataset', 'DirectoryDataSet', 'CLASSIFICATIO
            'PascalVOC2012SegmentationDataSetInterface', 'PascalAUG2012SegmentationDataSetInterface',
            'TestYoloDetectionDatasetInterface', 'DetectionTestDatasetInterface', 'ClassificationTestDatasetInterface',
            'SegmentationTestDatasetInterface',
-           'ImageNetDatasetInterface']
+           'ImageNetDatasetInterface',
+           'DetectionDataset', 'COCODetectionDataset', 'PascalVOCDetectionDataset']
@@ -3,10 +3,13 @@ from super_gradients.training.datasets.dataset_interfaces.dataset_interface impo
     ClassificationDatasetInterface, Cifar10DatasetInterface, Cifar100DatasetInterface, \
     ImageNetDatasetInterface, TinyImageNetDatasetInterface, CoCoSegmentationDatasetInterface, \
     PascalAUG2012SegmentationDataSetInterface, PascalVOC2012SegmentationDataSetInterface, \
-    TestYoloDetectionDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface
+    TestYoloDetectionDatasetInterface, SegmentationTestDatasetInterface, DetectionTestDatasetInterface, ClassificationTestDatasetInterface,\
+    CoCoDetectionDatasetInterface, PascalVOCUnifiedDetectionDatasetInterface
+
 
 
 __all__ = ['DatasetInterface', 'TestDatasetInterface', 'LibraryDatasetInterface', 'ClassificationDatasetInterface', 'Cifar10DatasetInterface',
            'Cifar100DatasetInterface', 'ImageNetDatasetInterface', 'TinyImageNetDatasetInterface',
            'CoCoSegmentationDatasetInterface', 'PascalAUG2012SegmentationDataSetInterface',
            'PascalVOC2012SegmentationDataSetInterface', 'TestYoloDetectionDatasetInterface', 'SegmentationTestDatasetInterface',
-           'DetectionTestDatasetInterface', 'ClassificationTestDatasetInterface']
+           'DetectionTestDatasetInterface', 'ClassificationTestDatasetInterface', 'CoCoDetectionDatasetInterface',
+           'PascalVOCUnifiedDetectionDatasetInterface']
@@ -1,36 +1,46 @@
 import os
+
 import numpy as np
 
 import torch
 import torchvision
 import torchvision.datasets as datasets
 from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import ConcatDataset, BatchSampler, DataLoader
+import torchvision.transforms as transforms
 
+
+from super_gradients.common import DatasetDataInterface
+from super_gradients.common.environment import AWS_ENV_NAME
 from super_gradients.common.abstractions.abstract_logger import get_logger
+
+from super_gradients.training import utils as core_utils
+from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master
+
+from super_gradients.training.utils import get_param
+from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
+
 from super_gradients.training.datasets import datasets_utils, DataAugmentation
+from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
 from super_gradients.training.datasets.data_augmentation import Lighting, RandomErase
-from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation, worker_init_reset_seed
-from super_gradients.training.datasets.detection_datasets.coco_detection import COCODetectionDataset
+from super_gradients.training.datasets.mixup import CollateMixup
+from super_gradients.training.datasets.detection_datasets import COCODetectionDataset, PascalVOCDetectionDataset
+
 from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
 from super_gradients.training.datasets.segmentation_datasets import PascalVOC2012SegmentationDataSet, \
     PascalAUG2012SegmentationDataSet, CoCoSegmentationDataSet
-from super_gradients.training import utils as core_utils
-from super_gradients.common import DatasetDataInterface
-from super_gradients.common.environment import AWS_ENV_NAME
-from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
-from super_gradients.training.datasets.mixup import CollateMixup
-from super_gradients.training.exceptions.dataset_exceptions import IllegalDatasetParameterException
 from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset
-from torch.utils.data import BatchSampler, DataLoader
-from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master
-from super_gradients.training.utils import get_param
-import torchvision.transforms as transforms
 from super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation import \
     SuperviselyPersonsDataset
+
 from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
-from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
-from super_gradients.training.transforms.transforms import DetectionMosaic, DetectionMixup, DetectionRandomAffine, DetectionTargetsFormatTransform, \
-    DetectionPaddedRescale, DetectionHSV, DetectionHorizontalFlip
+from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation, worker_init_reset_seed
+
+from super_gradients.training.transforms.transforms import DetectionMosaic, DetectionMixup, DetectionRandomAffine,\
+    DetectionTargetsFormatTransform, DetectionPaddedRescale, DetectionHSV, DetectionHorizontalFlip
+
+from super_gradients.training.exceptions.dataset_exceptions import IllegalDatasetParameterException
+
 
 default_dataset_params = {"batch_size": 64, "val_batch_size": 200, "test_batch_size": 200, "dataset_dir": "./data/",
                           "s3_link": None}
@@ -683,7 +693,92 @@ class SuperviselyPersonsDatasetInterface(DatasetInterface):
         self.classes = self.trainset.classes
 
 
-class CoCoDetectionDatasetInterface(DatasetInterface):
+class DetectionDatasetInterface(DatasetInterface):
+    def build_data_loaders(self, batch_size_factor=1, num_workers=8, train_batch_size=None, val_batch_size=None,
+                           test_batch_size=None, distributed_sampler: bool = False):
+
+        train_sampler = InfiniteSampler(len(self.trainset), seed=0)
+
+        train_batch_sampler = BatchSampler(
+            sampler=train_sampler,
+            batch_size=self.dataset_params.batch_size,
+            drop_last=False,
+        )
+
+        self.train_loader = DataLoader(self.trainset,
+                                       batch_sampler=train_batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True,
+                                       worker_init_fn=worker_init_reset_seed,
+                                       collate_fn=self.dataset_params.train_collate_fn)
+
+        if distributed_sampler:
+            sampler = torch.utils.data.distributed.DistributedSampler(self.valset, shuffle=False)
+        else:
+            sampler = torch.utils.data.SequentialSampler(self.valset)
+
+        val_loader = torch.utils.data.DataLoader(self.valset,
+                                                 num_workers=num_workers,
+                                                 pin_memory=True,
+                                                 sampler=sampler,
+                                                 batch_size=self.dataset_params.val_batch_size,
+                                                 collate_fn=self.dataset_params.val_collate_fn)
+
+        self.val_loader = val_loader
+
+
+class PascalVOCUnifiedDetectionDatasetInterface(DetectionDatasetInterface):
+
+    def __init__(self, dataset_params=None):
+        if dataset_params is None:
+            dataset_params = dict()
+        super().__init__(dataset_params=dataset_params)
+
+        self.data_dir = self.dataset_params.data_dir
+        train_input_dim = (self.dataset_params.train_image_size, self.dataset_params.train_image_size)
+        val_input_dim = (self.dataset_params.val_image_size, self.dataset_params.val_image_size)
+        train_max_num_samples = get_param(self.dataset_params, "train_max_num_samples")
+        val_max_num_samples = get_param(self.dataset_params, "val_max_num_samples")
+
+        if self.dataset_params.download:
+            PascalVOCDetectionDataset.download(data_dir=self.data_dir)
+
+        train_dataset_names = ["train2007", "val2007", "train2012", "val2012"]
+        # We divide train_max_num_samples between the datasets
+        if train_max_num_samples:
+            max_num_samples_per_train_dataset = [len(segment) for segment in np.array_split(range(train_max_num_samples), len(train_dataset_names))]
+        else:
+            max_num_samples_per_train_dataset = [None] * len(train_dataset_names)
+        train_sets = [PascalVOCDetectionDataset(data_dir=self.data_dir,
+                                                input_dim=train_input_dim,
+                                                cache=self.dataset_params.cache_train_images,
+                                                cache_path=self.dataset_params.cache_dir + "cache_train",
+                                                transforms=self.dataset_params.train_transforms,
+                                                images_sub_directory='images/' + trainset_name + '/',
+                                                class_inclusion_list=self.dataset_params.class_inclusion_list,
+                                                max_num_samples=max_num_samples_per_train_dataset[i])
+                      for i, trainset_name in enumerate(train_dataset_names)]
+
+        testset2007 = PascalVOCDetectionDataset(data_dir=self.data_dir,
+                                                input_dim=val_input_dim,
+                                                cache=self.dataset_params.cache_val_images,
+                                                cache_path=self.dataset_params.cache_dir + "cache_valid",
+                                                transforms=self.dataset_params.val_transforms,
+                                                images_sub_directory='images/test2007/',
+                                                class_inclusion_list=self.dataset_params.class_inclusion_list,
+                                                max_num_samples=val_max_num_samples)
+
+        self.classes = train_sets[1].classes
+        self.trainset = ConcatDataset(train_sets)
+        self.valset = testset2007
+
+        self.trainset.collate_fn = self.dataset_params.train_collate_fn
+        self.trainset.classes = self.classes
+        self.trainset.img_size = self.dataset_params.train_image_size
+        self.trainset.cache_labels = self.dataset_params.cache_train_images
+
+
+class CoCoDetectionDatasetInterface(DetectionDatasetInterface):
     def __init__(self, dataset_params={}):
         super(CoCoDetectionDatasetInterface, self).__init__(dataset_params=dataset_params)
 
@@ -743,36 +838,4 @@ class CoCoDetectionDatasetInterface(DatasetInterface):
                 cache=self.dataset_params.cache_val_images,
                 cache_dir_path=self.dataset_params.cache_dir_path,
                 with_crowd=with_crowd)
-
-    def build_data_loaders(self, batch_size_factor=1, num_workers=8, train_batch_size=None, val_batch_size=None,
-                           test_batch_size=None, distributed_sampler: bool = False):
-
-        train_sampler = InfiniteSampler(len(self.trainset), seed=0)
-
-        train_batch_sampler = BatchSampler(
-            sampler=train_sampler,
-            batch_size=self.dataset_params.batch_size,
-            drop_last=False,
-        )
-
-        self.train_loader = DataLoader(self.trainset,
-                                       batch_sampler=train_batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True,
-                                       worker_init_fn=worker_init_reset_seed,
-                                       collate_fn=self.dataset_params.train_collate_fn)
-
-        if distributed_sampler:
-            sampler = torch.utils.data.distributed.DistributedSampler(self.valset, shuffle=False)
-        else:
-            sampler = torch.utils.data.SequentialSampler(self.valset)
-
-        val_loader = torch.utils.data.DataLoader(self.valset,
-                                                 num_workers=num_workers,
-                                                 pin_memory=True,
-                                                 sampler=sampler,
-                                                 batch_size=self.dataset_params.val_batch_size,
-                                                 collate_fn=self.dataset_params.val_collate_fn)
-
-        self.val_loader = val_loader
         self.classes = COCO_DETECTION_CLASSES_LIST
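Review note: when train_max_num_samples is set, PascalVOCUnifiedDetectionDatasetInterface divides the budget across the four train sets with np.array_split, so the first splits absorb the remainder. A quick illustration (not part of the diff):

import numpy as np

train_dataset_names = ["train2007", "val2007", "train2012", "val2012"]
train_max_num_samples = 10
per_dataset = [len(segment) for segment in np.array_split(range(train_max_num_samples), len(train_dataset_names))]
print(per_dataset)  # [3, 3, 2, 2]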
@@ -35,3 +35,7 @@ COCO_DETECTION_CLASSES_LIST = [
     'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
     'teddy bear', 'hair drier', 'toothbrush'
 ]
+
+PASCAL_VOC_2012_CLASSES_LIST = [
+    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
+    'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
@@ -1,3 +1,6 @@
 from super_gradients.training.datasets.detection_datasets.coco_detection import COCODetectionDataset
+from super_gradients.training.datasets.detection_datasets.pascal_voc_detection import PascalVOCDetectionDataset
+from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
 
 
-__all__ = ['COCODetectionDataset']
+
+__all__ = ['COCODetectionDataset', 'DetectionDataset', 'PascalVOCDetectionDataset']
import os
from typing import List, Dict, Union, Any, Optional, Tuple
from multiprocessing.pool import ThreadPool
import random
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset

from super_gradients.training.utils.detection_utils import get_cls_posx_in_target, DetectionTargetsFormat
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.transforms.transforms import DetectionTransform, DetectionTargetsFormatTransform
from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException

logger = get_logger(__name__)


class DetectionDataset(Dataset):
    """Detection dataset.

    This is a boilerplate class to facilitate the implementation of datasets.

    HOW TO CREATE A DATASET THAT INHERITS FROM DetectionDataSet ?
        - Inherit from DetectionDataSet
        - implement the method self._load_annotation to return at least the fields "target" and "img_path"
        - Call super().__init__ with the required params.
            //!\\ super().__init__ will call self._load_annotation, so make sure that every required
                  attribute is set up before calling super().__init__ (ideally just call it last)

    WORKFLOW:
        - On instantiation:
            - All annotations are cached. If class_inclusion_list was specified, there is also subclassing at this step.
            - If cache is True, the images are also cached

        - On call (__getitem__) for a specific image index:
            - The image and annotations are grouped together in a dict called SAMPLE
            - the sample is processed according to the transforms
            - Only the specified fields are returned by __getitem__

    TERMINOLOGY
        - TARGET:       Ground truth, made of bboxes. The format can vary from one dataset to another.
        - ANNOTATION:   Combination of targets (ground truth) and metadata of the image, but without the image itself.
                            > Has to include the fields "target" and "img_path"
                            > Can include other fields like "crowd_target", "image_info", "segmentation", ...
        - SAMPLE:       Output of the dataset:
                            > Has to include the fields "target" and "image"
                            > Can include other fields like "crowd_target", "image_info", "segmentation", ...
        - INDEX:        Refers to the index in the dataset.
        - SAMPLE ID:    Refers to the id of a sample before dropping any annotation.
                        Let's imagine a situation where the downloaded data is made of 120 images but 20 were dropped
                        because they had no annotation. In that case:
                            > We have 120 samples, so sample_id will be between 0 and 119
                            > But only 100 will be indexed, so index will be between 0 and 99
                            > Therefore, we also have len(self) = 100
    """

    def __init__(
            self,
            data_dir: str,
            input_dim: tuple,
            original_target_format: DetectionTargetsFormat,
            max_num_samples: int = None,
            cache: bool = False,
            cache_path: str = None,
            transforms: List[DetectionTransform] = [],
            all_classes_list: Optional[List[str]] = None,
            class_inclusion_list: Optional[List[str]] = None,
            ignore_empty_annotations: bool = True,
            target_fields: List[str] = None,
            output_fields: List[str] = None,
    ):
        """Detection dataset.

        :param data_dir:                    Where the data is stored
        :param input_dim:                   Image size (when loaded, before transforms).
        :param original_target_format:      Format of the targets stored on disk (raw data format); the output format
                                            might differ based on transforms.
        :param max_num_samples:             If not None, set the maximum size of the dataset by only indexing the first n annotations/images.
        :param cache:                       Whether to cache images or not.
        :param cache_path:                  Path to the directory where cached images will be stored in an optimized format.
        :param transforms:                  List of transforms to apply sequentially on sample.
        :param all_classes_list:            All the class names.
        :param class_inclusion_list:        If not None, every class not included will be ignored.
        :param ignore_empty_annotations:    If True and class_inclusion_list not None, images without any target
                                            will be ignored.
        :param target_fields:               List of the target fields. This has to include the regular target,
                                            but can also include crowd target, segmentation target, ...
                                            It has to include at least "target" but can include other fields.
        :param output_fields:               Fields that will be output by __getitem__.
                                            It has to include at least "image" and "target" but can include other fields.
        """
        super().__init__()

        self.data_dir = data_dir
        if not Path(data_dir).exists():
            raise FileNotFoundError(f"Please make sure to download the data in the data directory ({self.data_dir}).")

        # Number of images that are available (regardless of ignored images)
        self.n_available_samples = self._setup_data_source()
        if not isinstance(self.n_available_samples, int) or self.n_available_samples < 1:
            raise ValueError(f"_setup_data_source() should return the number of available samples but got {self.n_available_samples}")

        self.input_dim = input_dim
        self.original_target_format = original_target_format
        self.max_num_samples = max_num_samples

        self.all_classes_list = all_classes_list
        self.class_inclusion_list = class_inclusion_list
        self.classes = self.class_inclusion_list or self.all_classes_list
        if len(set(self.classes) - set(all_classes_list)) > 0:
            wrong_classes = set(self.classes) - set(all_classes_list)
            raise ValueError(f"class_inclusion_list includes classes that are not in all_classes_list: {wrong_classes}")

        self.ignore_empty_annotations = ignore_empty_annotations
        self.target_fields = target_fields or ["target"]
        if "target" not in self.target_fields:
            raise KeyError('"target" is expected to be in the fields to subclass but it was not included')

        self.annotations = self._cache_annotations()

        self.cache = cache
        self.cache_path = cache_path
        self.cached_imgs = self._cache_images() if self.cache else None

        self.transforms = transforms

        self.output_fields = output_fields or ["image", "target"]
        if len(self.output_fields) < 2 or self.output_fields[0] != "image" or self.output_fields[1] != "target":
            raise ValueError('output_fields must start with "image" and then "target", followed by any other field')

    def _setup_data_source(self) -> int:
        """Set up the data source and store relevant objects as attributes.

        :return: Number of available samples (i.e. how many images we have, regardless of any filter we might want to use)"""
        raise NotImplementedError

    def _load_annotation(self, sample_id: int) -> Dict[str, Union[np.ndarray, Any]]:
        """Load annotations associated to a specific sample.
        Please note that the targets should be resized according to self.input_dim!

        :param sample_id:   Id of the sample to load annotations from.
        :return:            Annotation, a dict with any field but that has to include at least "target" and "img_path".
        """
        raise NotImplementedError

    def _cache_annotations(self) -> List[Dict[str, Union[np.ndarray, Any]]]:
        """Load all the annotations to memory to avoid opening files back and forth.

        :return: List of annotations
        """
        annotations = []
        for sample_id, img_id in enumerate(tqdm(range(self.n_available_samples), desc="Caching annotations")):
            if self.max_num_samples is not None and len(annotations) >= self.max_num_samples:
                break

            img_annotation = self._load_annotation(img_id)
            if "target" not in img_annotation or "img_path" not in img_annotation:
                raise KeyError('_load_annotation is expected to return at least the fields "target" and "img_path"')

            if self.class_inclusion_list is not None:
                img_annotation = self._sub_class_annotation(img_annotation)

            is_annotation_empty = all(len(img_annotation[field]) == 0 for field in self.target_fields)
            if self.ignore_empty_annotations and is_annotation_empty:
                continue
            annotations.append(img_annotation)

        if len(annotations) == 0:
            raise EmptyDatasetException(f"Out of {self.n_available_samples} images, not a single one was found with "
                                        f"any of these classes: {self.class_inclusion_list}")
        return annotations

    def _sub_class_annotation(self, annotation: dict) -> Union[dict, None]:
        """Subclass every field listed in self.target_fields. It could be targets, crowd_targets, ...

        :param annotation:  Dict representing the annotation of a specific image
        :return:            Subclassed annotation if non-empty after subclassing, otherwise None
        """
        cls_posx = get_cls_posx_in_target(self.original_target_format)
        for field in self.target_fields:
            annotation[field] = self._sub_class_target(targets=annotation[field], cls_posx=cls_posx)
        return annotation

    def _sub_class_target(self, targets: np.ndarray, cls_posx: int) -> np.ndarray:
        """Subclass targets of a specific image.

        :param targets:     Target array to subclass of shape [n_targets, 5], 5 representing a bbox
        :param cls_posx:    Position of the class id in a bbox
                                ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label
        :return:            Subclassed target
        """
        targets_kept = []
        for target in targets:
            cls_id = int(target[cls_posx])
            cls_name = self.all_classes_list[cls_id]
            if cls_name in self.class_inclusion_list:
                # Replace the target cls_id in self.all_classes_list by cls_id in self.class_inclusion_list
                target[cls_posx] = self.class_inclusion_list.index(cls_name)
                targets_kept.append(target)

        return np.array(targets_kept) if len(targets_kept) > 0 else np.zeros((0, 5), dtype=np.float32)

    def _cache_images(self) -> np.ndarray:
        """Cache the images. The cached images are stored in a file to be loaded faster next time.

        :return: Cached images
        """
        if self.cache_path is None:
            raise ValueError("You must specify a cache_path if you want to cache your images. "
                             "If you did not mean to use cache, please set cache=False")
        cache_path = Path(self.cache_path)
        cache_path.mkdir(parents=True, exist_ok=True)

        logger.warning("\n********************************************************************************\n"
                       "You are using cached images in RAM to accelerate training.\n"
                       "This requires large system RAM.\n"
                       "********************************************************************************\n")

        max_h, max_w = self.input_dim[0], self.input_dim[1]
        img_resized_cache_path = cache_path / "img_resized_cache.array"

        if not img_resized_cache_path.exists():
            logger.info("Caching images for the first time.")
            NUM_THREADs = min(8, os.cpu_count())
            loaded_images = ThreadPool(NUM_THREADs).imap(func=lambda x: self._load_image(x), iterable=range(len(self)))

            # Initialize placeholder for images
            cached_imgs = np.memmap(str(img_resized_cache_path), shape=(len(self), max_h, max_w, 3),
                                    dtype=np.uint8, mode="w+")

            # Store images in the placeholder
            loaded_images_pbar = tqdm(enumerate(loaded_images), total=len(self))
            for i, image in loaded_images_pbar:
                cached_imgs[i][: image.shape[0], : image.shape[1], :] = image.copy()
            cached_imgs.flush()
            loaded_images_pbar.close()
        else:
            logger.warning("You are using cached imgs! Make sure your dataset is not changed!!\n"
                           "Every time self.input_size is changed in your exp file, you need to delete\n"
                           "the cached data and re-generate it.\n")

        logger.info("Loading cached imgs...")
        cached_imgs = np.memmap(str(img_resized_cache_path), shape=(len(self), max_h, max_w, 3),
                                dtype=np.uint8, mode="r+")
        return cached_imgs

    def _load_resized_img(self, index: int) -> np.ndarray:
        """Load an image and resize it to self.input_dim.

        :param index:   Image index
        :return:        Resized image
        """
        img = self._load_image(index)

        r = min(self.input_dim[0] / img.shape[0], self.input_dim[1] / img.shape[1])
        desired_size = (int(img.shape[1] * r), int(img.shape[0] * r))

        resized_img = cv2.resize(src=img, dsize=desired_size, interpolation=cv2.INTER_LINEAR).astype(np.uint8)
        return resized_img

    def _load_image(self, index: int) -> np.ndarray:
        """Loads the image at index with its original resolution.

        :param index:   Image index
        :return:        Image in array format
        """
        img_path = self.annotations[index]["img_path"]

        img_file = os.path.join(img_path)
        img = cv2.imread(img_file)

        if img is None:
            raise FileNotFoundError(f"{img_file} was not found. Please make sure that the dataset was "
                                    f"downloaded and that the path is correct")
        return img

    def __del__(self):
        """Clear the cached images"""
        if hasattr(self, "cached_imgs"):
            del self.cached_imgs

    def __len__(self):
        """Get the length of the dataset."""
        return len(self.annotations)

    def __getitem__(self, index: int) -> Tuple:
        """Get the sample post transforms at a specific index of the dataset.
        The output of this function will be collated to form batches."""
        sample = self.apply_transforms(self.get_sample(index))
        for field in self.output_fields:
            if field not in sample.keys():
                raise KeyError(f'The field {field} must be present in the sample but was not found. '
                               'Please check the output fields of your transforms.')
        return tuple(sample[field] for field in self.output_fields)

    def get_random_item(self):
        return self[self._random_index()]

    def get_sample(self, index: int) -> Dict[str, Union[np.ndarray, Any]]:
        """Get a raw sample, before any transform (besides subclassing).

        :param index:   Image index
        :return:        Sample, i.e. a dictionary including at least "image" and "target"
        """
        img = self.get_resized_image(index)
        annotation = self.annotations[index]
        return {"image": img, **annotation}

    def get_resized_image(self, index: int) -> np.ndarray:
        """
        Get the resized image at a specific sample_id, either from cache or by loading from disk, based on self.cache.

        :param index:   Image index
        :return:        Resized image
        """
        if self.cache:
            return self.cached_imgs[index].copy()
        else:
            return self._load_resized_img(index)

    def apply_transforms(self, sample: Dict[str, Union[np.ndarray, Any]]) -> Dict[str, Union[np.ndarray, Any]]:
        """
        Applies self.transforms sequentially to sample

        If a transform has the attribute 'additional_samples_count', additional samples will be loaded and stored in
        sample["additional_samples"] prior to applying it. Combining this with the attribute "non_empty_annotations"
        will load only additional samples with objects in them.

        :param sample: Sample to apply the transforms on (loaded with self.get_sample)
        :return: Transformed sample
        """
        for transform in self.transforms:
            self._add_additional_inputs_for_transform(sample, transform)
            sample = transform(sample)
            sample.pop("additional_samples")  # additional_samples is not useful after the transform
        return sample

    def _add_additional_inputs_for_transform(self, sample: Dict[str, Union[np.ndarray, Any]],
                                             transform: DetectionTransform):
        """Add additional inputs required by a transform to the sample"""
        additional_samples_count = transform.additional_samples_count if hasattr(transform, "additional_samples_count") else 0
        non_empty_annotations = transform.non_empty_annotations if hasattr(transform, "non_empty_annotations") else False
        additional_samples = self.get_random_samples(additional_samples_count, non_empty_annotations)
        sample["additional_samples"] = additional_samples

    def get_random_samples(self, count: int,
                           non_empty_annotations_only: bool = False) -> List[Dict[str, Union[np.ndarray, Any]]]:
        """Load random samples.

        :param count: The number of samples wanted
        :param non_empty_annotations_only: If true, only return samples with at least 1 annotation
        :return: A list of samples satisfying input params
        """
        return [self.get_random_sample(non_empty_annotations_only) for _ in range(count)]

    def get_random_sample(self, non_empty_annotations_only: bool = False):
        if non_empty_annotations_only:
            return self.get_sample(self._get_random_non_empty_annotation_available_indexes())
        else:
            return self.get_sample(self._random_index())

    def _get_random_non_empty_annotation_available_indexes(self) -> int:
        """Get the index of a non-empty annotation.

        :return: Image index"""
        target, index = [], -1
        while len(target) == 0:
            index = self._random_index()
            target = self.annotations[index]["target"]
        return index

    def _random_index(self):
        """Get a random index of this dataset"""
        return random.randint(0, len(self) - 1)

    @property
    def output_target_format(self):
        target_format = self.original_target_format
        for transform in self.transforms:
            if isinstance(transform, DetectionTargetsFormatTransform):
                target_format = transform.output_format
        return target_format

    def plot(self, max_samples_per_plot: int = 16, n_plots: int = 1, plot_transformed_data: bool = True):
        """Combine samples of images with bboxes into plots and display the result.

        :param max_samples_per_plot:    Maximum number of images to be displayed per plot
        :param n_plots:                 Number of plots to display (each plot being a combination of img with bbox)
        :param plot_transformed_data:   If True, the plot will be over samples after applying transforms (i.e. on __getitem__).
                                        If False, the plot will be over the raw samples (i.e. on get_sample)
        :return:
        """
        plot_counter = 0
        input_format = self.output_target_format if plot_transformed_data else self.original_target_format
        target_format_transform = DetectionTargetsFormatTransform(input_format=input_format,
                                                                  output_format=DetectionTargetsFormat.XYXY_LABEL)

        for plot_i in range(n_plots):
            fig = plt.figure(figsize=(10, 10))
            n_subplot = int(np.ceil(max_samples_per_plot ** 0.5))
            for img_i in range(max_samples_per_plot):
                index = img_i + plot_i * 16

                if plot_transformed_data:
                    image, targets, *_ = self[img_i + plot_i * 16]
                    image = image.transpose(1, 2, 0).astype(np.int32)
                else:
                    sample = self.get_sample(index)
                    image, targets = sample["image"], sample["target"]

                sample = target_format_transform({"image": image, "target": targets})

                # shape = [padding_size x 4] (The dataset will most likely pad the targets to a fixed dim)
                boxes = sample["target"][:, 0:4]

                # shape = [n_box x 4] (We remove padded boxes, which correspond to boxes with only 0)
                boxes = boxes[(boxes != 0).any(axis=1)]
                plt.subplot(n_subplot, n_subplot, img_i + 1).imshow(image)
                plt.plot(boxes[:, [0, 2, 2, 0, 0]].T, boxes[:, [1, 1, 3, 3, 1]].T, '.-')
                plt.axis('off')

            fig.tight_layout()
            plt.show()
            plt.close()

            plot_counter += 1
            if plot_counter == n_plots:
                return
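Following the "HOW TO CREATE A DATASET" docstring above, a minimal subclass might look like this (a sketch only; the images/ + labels/ file layout and class names below are hypothetical and not part of this PR):

# Hypothetical minimal subclass of DetectionDataset.
import glob
import numpy as np

from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat


class MyDetectionDataset(DetectionDataset):
    def __init__(self, *args, **kwargs):
        kwargs["original_target_format"] = DetectionTargetsFormat.XYXY_LABEL
        kwargs["all_classes_list"] = ["cat", "dog"]  # placeholder class names
        # Any attribute needed by _setup_data_source must be set BEFORE
        # super().__init__, which triggers annotation caching.
        super().__init__(*args, **kwargs)

    def _setup_data_source(self) -> int:
        self.img_paths = sorted(glob.glob(self.data_dir + "images/*.jpg"))
        return len(self.img_paths)

    def _load_annotation(self, sample_id: int) -> dict:
        img_path = self.img_paths[sample_id]
        label_path = img_path.replace("images", "labels").replace(".jpg", ".txt")
        # One "x1 y1 x2 y2 label" row per box, assumed already scaled to self.input_dim
        target = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 5)
        return {"img_path": img_path, "target": target}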
import os
import glob
from pathlib import Path
from xml.etree import ElementTree

from tqdm import tqdm
import numpy as np

from super_gradients.training.utils.utils import download_and_untar_from_url, get_image_size_from_path
from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.datasets.datasets_conf import PASCAL_VOC_2012_CLASSES_LIST

logger = get_logger(__name__)


class PascalVOCDetectionDataset(DetectionDataset):
    """Dataset for Pascal VOC object detection"""

    def __init__(self, images_sub_directory: str, *args, **kwargs):
        """Dataset for Pascal VOC object detection

        :param images_sub_directory:    Sub directory of data_dir that includes images.
        """
        self.images_sub_directory = images_sub_directory
        self.img_and_target_path_list = None
        kwargs['all_classes_list'] = PASCAL_VOC_2012_CLASSES_LIST
        kwargs['original_target_format'] = DetectionTargetsFormat.XYXY_LABEL
        super().__init__(*args, **kwargs)

    def _setup_data_source(self):
        """Initialize img_and_target_path_list and warn if a label file is missing.

        :return: Number of (img_path, target_path) pairs that were found
        """
        img_files_folder = self.data_dir + self.images_sub_directory
        if not Path(img_files_folder).exists():
            raise FileNotFoundError(f"{self.data_dir} does not include {self.images_sub_directory}. "
                                    f"Please make sure that {self.data_dir} refers to PascalVOC dataset and that "
                                    "it was downloaded using PascalVOCDetectionDataset.download()")

        img_files = glob.glob(img_files_folder + "*.jpg")
        if len(img_files) == 0:
            raise FileNotFoundError(f"No image file found at {img_files_folder}")

        target_files = [img_file.replace("images", "labels").replace(".jpg", ".txt") for img_file in img_files]

        img_and_target_path_list = [(img_file, target_file)
                                    for img_file, target_file in zip(img_files, target_files)
                                    if os.path.exists(target_file)]
        if len(img_and_target_path_list) == 0:
            raise FileNotFoundError("No target file associated to the images was found")

        num_missing_files = len(img_files) - len(img_and_target_path_list)
        if num_missing_files > 0:
            logger.warning(f'{num_missing_files} label files were not loaded out of {len(img_files)} image files')

        self.img_and_target_path_list = img_and_target_path_list
        return len(self.img_and_target_path_list)

    def _load_annotation(self, sample_id: int) -> dict:
        """Load annotations associated to a specific sample.

        :return: Annotation including:
                    - target in XYXY_LABEL format
                    - img_path
        """
        img_path, target_path = self.img_and_target_path_list[sample_id]
        with open(target_path, 'r') as targets_file:
            target = np.array([x.split() for x in targets_file.read().splitlines()], dtype=np.float32)

        width, height = get_image_size_from_path(img_path)

        # We have to rescale the targets because the images will be rescaled.
        r = min(self.input_dim[1] / height, self.input_dim[0] / width)
        target[:, :4] *= r

        initial_img_shape = (width, height)
        resized_img_shape = (int(width * r), int(height * r))

        return {"img_path": img_path, "target": target,
                "initial_img_shape": initial_img_shape, "resized_img_shape": resized_img_shape}

    @staticmethod
    def download(data_dir: str):
        """Download Pascal dataset in XYXY_LABEL format.

        Data extracted from http://host.robots.ox.ac.uk/pascal/VOC/
        """

        def _parse_and_save_labels(path, new_label_path, year, image_id):
            """Parse and save the labels of an image in XYXY_LABEL format."""

            with open(f'{path}/VOC{year}/Annotations/{image_id}.xml') as f:
                xml_parser = ElementTree.parse(f).getroot()

            labels = []
            for obj in xml_parser.iter('object'):
                cls = obj.find('name').text
                if cls in PASCAL_VOC_2012_CLASSES_LIST and not int(obj.find('difficult').text) == 1:
                    xml_box = obj.find('bndbox')

                    def get_coord(box_coord):
                        return xml_box.find(box_coord).text

                    xmin, ymin, xmax, ymax = get_coord("xmin"), get_coord("ymin"), get_coord("xmax"), get_coord("ymax")
                    labels.append(" ".join([xmin, ymin, xmax, ymax, str(PASCAL_VOC_2012_CLASSES_LIST.index(cls))]))

            with open(new_label_path, 'w') as f:
                f.write("\n".join(labels))

        urls = ["http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",  # 439M, 5011 images
                "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",  # 430M, 4952 images
                "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"]  # 1.86G, 17125 images
        data_dir = Path(data_dir)
        download_and_untar_from_url(urls, dir=data_dir / 'images')

        # Convert
        data_path = data_dir / 'images' / 'VOCdevkit'
        for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'):
            dest_imgs_path = data_dir / 'images' / f'{image_set}{year}'
            dest_imgs_path.mkdir(exist_ok=True, parents=True)
            dest_labels_path = data_dir / 'labels' / f'{image_set}{year}'
            dest_labels_path.mkdir(exist_ok=True, parents=True)

            with open(data_path / f'VOC{year}/ImageSets/Main/{image_set}.txt') as f:
                image_ids = f.read().strip().split()

            for id in tqdm(image_ids, desc=f'{image_set}{year}'):
                img_path = data_path / f'VOC{year}/JPEGImages/{id}.jpg'
                new_img_path = dest_imgs_path / img_path.name
                new_label_path = (dest_labels_path / img_path.name).with_suffix('.txt')
                img_path.rename(new_img_path)  # Move image to dest folder
                _parse_and_save_labels(data_path, new_label_path, year, id)
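Usage sketch for the subclassing mechanism (data_dir is a placeholder, not taken from this PR):

# Hypothetical usage of the class above.
from super_gradients.training.datasets import PascalVOCDetectionDataset

dataset = PascalVOCDetectionDataset(data_dir="./data/pascal_voc/",
                                    images_sub_directory="images/train2012/",
                                    input_dim=(640, 640),
                                    class_inclusion_list=["person", "dog"])
# Labels are re-indexed against class_inclusion_list, and images without
# any kept box are dropped (ignore_empty_annotations defaults to True).
dataset.plot(max_samples_per_plot=16, n_plots=1, plot_transformed_data=False)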
@@ -12,6 +12,19 @@ class IllegalDatasetParameterException(Exception):
         super().__init__(self.message)
 
 
+class EmptyDatasetException(Exception):
+    """
+    Exception raised when a dataset does not have any image for a specific config
+
+    Attributes:
+        message -- explanation of the error
+    """
+
+    def __init__(self, desc):
+        self.message = "Empty Dataset: " + desc
+        super().__init__(self.message)
+
+
 class UnsupportedBatchItemsFormat(ValueError):
     """Exception raised illegal batch items returned from data loader.
 
@@ -6,13 +6,13 @@ from enum import Enum
 from typing import Callable, List, Union, Tuple, Optional, Dict
 
 import cv2
-from torch.utils.data._utils.collate import default_collate
 import matplotlib.pyplot as plt
 
+import numpy as np
 import torch
 import torchvision
-import numpy as np
 from torch import nn
+from torch.utils.data._utils.collate import default_collate
 from omegaconf import ListConfig
 
 
@@ -35,6 +35,21 @@ class DetectionTargetsFormat(Enum):
     NORMALIZED_CXCYWH_LABEL = "NORMALIZED_CXCYWH_LABEL"
 
 
+def get_cls_posx_in_target(target_format: DetectionTargetsFormat) -> int:
+    """Get the label of a given target
+    :param target_format:   Representation of the target (ex: LABEL_XYXY)
+    :return:                Position of the class id in a bbox
+                                ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label
+    """
+    format_split = target_format.value.split("_")
+    if format_split[0] == "LABEL":
+        return 0
+    elif format_split[-1] == "LABEL":
+        return -1
+    else:
+        raise NotImplementedError(f"No implementation to find index of LABEL in {target_format.value}")
+
+
 def _set_batch_labels_index(labels_batch):
     for i, labels in enumerate(labels_batch):
         labels[:, 0] = i
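Quick illustration of the new helper (not part of the diff; assumes LABEL_XYXY is among the enum members):

from super_gradients.training.utils.detection_utils import DetectionTargetsFormat, get_cls_posx_in_target

assert get_cls_posx_in_target(DetectionTargetsFormat.LABEL_XYXY) == 0   # class id stored first
assert get_cls_posx_in_target(DetectionTargetsFormat.XYXY_LABEL) == -1  # class id stored last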
@@ -1,20 +1,28 @@
 import math
 import time
+from functools import lru_cache
 from pathlib import Path
-from typing import Mapping, Optional, Tuple, Union
+from typing import Mapping, Optional, Tuple, Union, List
 from zipfile import ZipFile
 import os
 from jsonschema import validate
+import tarfile
+from PIL import Image, ExifTags
 
 import torch
 import torch.nn as nn
 
+
 # These functions changed from torch 1.2 to torch 1.3
 
 import random
 import numpy as np
 from importlib import import_module
 
+from super_gradients.common.abstractions.abstract_logger import get_logger
+
+logger = get_logger(__name__)
+
 
 def convert_to_tensor(array):
     """Converts numpy arrays and lists to Torch tensors before calculation losses
@@ -338,6 +346,35 @@ def download_and_unzip_from_url(url, dir='.', unzip=True, delete=True):
         download_one(u, dir)
 
 
+def download_and_untar_from_url(urls: List[str], dir: Union[str, Path] = '.'):
+    """
+    Download a file from url and untar.
+
+    :param urls:    URLs to download the files from.
+    :param dir:     Destination directory.
+    """
+    dir = Path(dir)
+    dir.mkdir(parents=True, exist_ok=True)
+
+    for url in urls:
+        url_path = Path(url)
+        filepath = dir / url_path.name
+
+        if url_path.is_file():
+            url_path.rename(filepath)
+        elif not filepath.exists():
+            logger.info(f'Downloading {url} to {filepath}...')
+            torch.hub.download_url_to_file(url, str(filepath), progress=True)
+
+        modes = {".tar.gz": "r:gz", ".tar": "r:"}
+        assert filepath.suffix in modes.keys(), f"{filepath} has {filepath.suffix} suffix which is not supported"
+
+        logger.info(f'Extracting to {dir}...')
+        with tarfile.open(filepath, mode=modes[filepath.suffix]) as f:
+            f.extractall(dir)
+        filepath.unlink()
+
+
 def make_divisible(x: int, divisor: int, ceil: bool = True) -> int:
     """
     Returns x evenly divisible by divisor.
@@ -362,3 +399,43 @@ def check_img_size_divisibility(img_size: int, stride: int = 32) -> Tuple[bool,
         return False, (new_size, make_divisible(img_size, int(stride), ceil=False))
     else:
         return True, None
+
+
+@lru_cache(None)
+def get_orientation_key() -> int:
+    """Get the orientation key according to PIL, which is useful to get the image size for instance
+    :return: Orientation key according to PIL"""
+    for key, value in ExifTags.TAGS.items():
+        if value == 'Orientation':
+            return key
+
+
+def exif_size(image: Image) -> Tuple[int, int]:
+    """Get the size of image.
+    :param image:   The image to get size from
+    :return:        (width, height)
+    """
+
+    orientation_key = get_orientation_key()
+
+    image_size = image.size
+    try:
+        exif_data = image._getexif()
+        if exif_data is not None:
+            rotation = dict(exif_data.items())[orientation_key]
+            # ROTATION 270
+            if rotation == 6:
+                image_size = (image_size[1], image_size[0])
+            # ROTATION 90
+            elif rotation == 8:
+                image_size = (image_size[1], image_size[0])
+    except Exception as ex:
+        print('Caught Exception trying to rotate: ' + str(image) + str(ex))
+    height, width = image_size
+    return width, height
+
+
+def get_image_size_from_path(img_path: str) -> Tuple[int, int]:
+    """Get the image size of an image at a specific path"""
+    with open(img_path, 'rb') as f:
+        return exif_size(Image.open(f))
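Usage sketch for the new download/EXIF helpers (the URL is one of those used by PascalVOCDetectionDataset.download; the destination and image path are illustrative):

from super_gradients.training.utils.utils import download_and_untar_from_url, get_image_size_from_path

# Download and extract one VOC archive (hypothetical destination directory)
download_and_untar_from_url(["http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar"],
                            dir="./data/images")
# Read an image size via EXIF-aware PIL inspection (path illustrative)
width, height = get_image_size_from_path("./data/images/VOCdevkit/VOC2007/JPEGImages/000001.jpg")
print(width, height)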
@@ -24,6 +24,8 @@ from tests.unit_tests.lr_cooldown_test import LRCooldownTest
 from tests.unit_tests.detection_targets_format_transform_test import DetectionTargetsTransformTest
 from tests.unit_tests.forward_pass_prep_fn_test import ForwardpassPrepFNTest
 from tests.unit_tests.mask_loss_test import MaskAttentionLossTest
+from tests.unit_tests.detection_sub_sampling_test import TestDetectionDatasetSubsampling
+from tests.unit_tests.detection_sub_classing_test import TestDetectionDatasetSubclassing
 
 
 class CoreUnitTestSuiteRunner:
@@ -70,6 +72,8 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(UpdateParamGroupsTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MaskAttentionLossTest))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(IoULossTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubsampling))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubclassing))
 
     def _add_modules_to_end_to_end_tests_suite(self):
         """
import unittest

import super_gradients
from super_gradients.training.datasets import PascalVOCDetectionDataset
from super_gradients.training.transforms import DetectionMosaic, DetectionPaddedRescale, DetectionTargetsFormatTransform
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException


class DatasetIntegrationTest(unittest.TestCase):

    def setUp(self) -> None:
        super_gradients.init_trainer()
        self.batch_size = 64
        self.pascal_class_inclusion_lists = [['aeroplane', 'bicycle'],
                                             ['bird', 'boat', 'bottle', 'bus'],
                                             ['pottedplant'],
                                             ['person']]
        transforms = [DetectionMosaic(input_dim=(640, 640), prob=0.8),
                      DetectionPaddedRescale(input_dim=(640, 640), max_targets=120),
                      DetectionTargetsFormatTransform(output_format=DetectionTargetsFormat.XYXY_LABEL)]
        self.pascal_base_config = dict(data_dir='/home/louis.dupont/data/pascal_unified_coco_format/',
                                       images_sub_directory='images/train2012/',
                                       input_dim=(640, 640),
                                       transforms=transforms)

    def test_multiple_pascal_dataset_subclass_before_transforms(self):
        """Instantiate PascalVOCDetectionDataset on multiple class inclusion lists and plot the raw samples."""
        for class_inclusion_list in self.pascal_class_inclusion_lists:
            dataset = PascalVOCDetectionDataset(class_inclusion_list=class_inclusion_list, **self.pascal_base_config)
            dataset.plot(max_samples_per_plot=16, n_plots=1, plot_transformed_data=False)

    def test_multiple_pascal_dataset_subclass_after_transforms(self):
        """Instantiate PascalVOCDetectionDataset on multiple class inclusion lists and plot the transformed samples."""
        for class_inclusion_list in self.pascal_class_inclusion_lists:
            dataset = PascalVOCDetectionDataset(class_inclusion_list=class_inclusion_list, **self.pascal_base_config)
            dataset.plot(max_samples_per_plot=16, n_plots=1, plot_transformed_data=True)

    def test_subclass_non_existing_class(self):
        """Check that ValueError is raised when class_inclusion_list contains an unknown label."""
        with self.assertRaises(ValueError):
            PascalVOCDetectionDataset(class_inclusion_list=["new_class"], **self.pascal_base_config)

    def test_sub_sampling_dataset(self):
        """Check that sub-sampling caps the dataset length at max_num_samples."""
        full_dataset = PascalVOCDetectionDataset(**self.pascal_base_config)

        with self.assertRaises(EmptyDatasetException):
            PascalVOCDetectionDataset(max_num_samples=0, **self.pascal_base_config)

        for max_num_samples in [1, 10, 1000, 1_000_000]:
            sampled_dataset = PascalVOCDetectionDataset(max_num_samples=max_num_samples, **self.pascal_base_config)
            self.assertEqual(len(sampled_dataset), min(max_num_samples, len(full_dataset)))


if __name__ == '__main__':
    unittest.main()
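Outside the test harness, the same dataset plugs into a standard PyTorch DataLoader; a minimal sketch, assuming a local copy of the unified Pascal VOC data (the data_dir path and batch size below are placeholders):

from torch.utils.data import DataLoader

from super_gradients.training.datasets import PascalVOCDetectionDataset
from super_gradients.training.utils.detection_utils import DetectionCollateFN

# Placeholder path: point data_dir at a local copy of the unified Pascal VOC data.
dataset = PascalVOCDetectionDataset(data_dir='/path/to/pascal_unified_coco_format/',
                                    images_sub_directory='images/train2012/',
                                    input_dim=(640, 640),
                                    class_inclusion_list=['person'])

# DetectionCollateFN is the detection collate function used by the interfaces in this PR.
loader = DataLoader(dataset, batch_size=16, collate_fn=DetectionCollateFN())
images, targets = next(iter(loader))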
@@ -1,12 +1,71 @@
 import unittest
-from super_gradients.training.datasets import Cifar10DatasetInterface
+
+from super_gradients.training.datasets.dataset_interfaces.dataset_interface import PascalVOCUnifiedDetectionDatasetInterface
+from super_gradients.training.transforms.transforms import DetectionPaddedRescale, DetectionTargetsFormatTransform, DetectionMosaic, DetectionRandomAffine,\
+    DetectionHSV
+from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
+from super_gradients.training.utils.detection_utils import DetectionCollateFN
+from super_gradients.training.utils import sg_model_utils
+from super_gradients.training import utils as core_utils


 class TestDatasetInterface(unittest.TestCase):
-    def test_cifar(self):
-        test_dataset_interface = Cifar10DatasetInterface()
-        cifar_dataset_sample = test_dataset_interface.get_test_sample()
-        self.assertListEqual([3, 32, 32], list(cifar_dataset_sample[0].shape))
+    def setUp(self) -> None:
+        self.root_dir = "/home/data/"
+        self.train_batch_size, self.val_batch_size = 16, 32
+        self.train_image_size, self.val_image_size = 640, 640
+        self.train_input_dim = (self.train_image_size, self.train_image_size)
+        self.val_input_dim = (self.val_image_size, self.val_image_size)
+        self.train_max_num_samples = 100
+        self.val_max_num_samples = 90
+
+    def setup_pascal_voc_interface(self):
+        """Set up PascalVOCUnifiedDetectionDatasetInterface and return its train/val dataloaders."""
+        dataset_params = {
+            "data_dir": self.root_dir + "pascal_unified_coco_format/",
+            "cache_dir": self.root_dir + "pascal_unified_coco_format/",
+            "batch_size": self.train_batch_size,
+            "val_batch_size": self.val_batch_size,
+            "train_image_size": self.train_image_size,
+            "val_image_size": self.val_image_size,
+            "train_max_num_samples": self.train_max_num_samples,
+            "val_max_num_samples": self.val_max_num_samples,
+            "train_transforms": [
+                DetectionMosaic(input_dim=self.train_input_dim, prob=1),
+                DetectionRandomAffine(degrees=0.373, translate=0.245, scales=0.898, shear=0.602, target_size=self.train_input_dim),
+                DetectionHSV(prob=1, hgain=0.0138, sgain=0.664, vgain=0.464),
+                DetectionPaddedRescale(input_dim=self.train_input_dim, max_targets=100),
+                DetectionTargetsFormatTransform(input_format=DetectionTargetsFormat.XYXY_LABEL,
+                                                output_format=DetectionTargetsFormat.LABEL_CXCYWH)],
+            "val_transforms": [
+                DetectionPaddedRescale(input_dim=self.val_input_dim),
+                DetectionTargetsFormatTransform(input_format=DetectionTargetsFormat.XYXY_LABEL,
+                                                output_format=DetectionTargetsFormat.LABEL_CXCYWH)],
+            "train_collate_fn": DetectionCollateFN(),
+            "val_collate_fn": DetectionCollateFN(),
+            "download": False,
+            "cache_train_images": False,
+            "cache_val_images": False,
+            "class_inclusion_list": ["person"]
+        }
+        dataset_interface = PascalVOCUnifiedDetectionDatasetInterface(dataset_params=dataset_params)
+        train_loader, valid_loader, _test_loader, _classes = dataset_interface.get_data_loaders()
+        return train_loader, valid_loader
+
+    def test_pascal_voc(self):
+        """Check that the dataset interface is correctly instantiated, and that the batch items are of expected size"""
+        train_loader, valid_loader = self.setup_pascal_voc_interface()
+
+        for loader, batch_size, image_size, max_num_samples in [(train_loader, self.train_batch_size, self.train_image_size, self.train_max_num_samples),
+                                                                (valid_loader, self.val_batch_size, self.val_image_size, self.val_max_num_samples)]:
+            # The dataset holds at most max_num_samples samples, but may be smaller if the source has fewer
+            self.assertGreaterEqual(max_num_samples, len(loader.dataset))
+
+            batch_items = next(iter(loader))
+            batch_items = core_utils.tensor_container_to_device(batch_items, 'cuda', non_blocking=True)
+
+            inputs, targets, additional_batch_items = sg_model_utils.unpack_batch_items(batch_items)
+            self.assertListEqual([batch_size, 3, image_size, image_size], list(inputs.shape))


 if __name__ == '__main__':
import unittest

import numpy as np

from super_gradients.training.datasets import DetectionDataset
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
from super_gradients.training.exceptions.dataset_exceptions import EmptyDatasetException


class DummyDetectionDataset(DetectionDataset):
    def __init__(self, input_dim, *args, **kwargs):
        """Dummy dataset for testing subclassing, designed so that no annotation includes class_2."""
        self.dummy_targets = [np.array([[0, 0, 10, 10, 0],
                                        [0, 5, 10, 15, 0],
                                        [0, 5, 15, 20, 0]]),
                              np.array([[0, 0, 10, 10, 0],
                                        [0, 5, 10, 15, 0],
                                        [0, 15, 55, 20, 1]])]
        self.image_size = input_dim
        kwargs['all_classes_list'] = ["class_0", "class_1", "class_2"]
        kwargs['original_target_format'] = DetectionTargetsFormat.XYXY_LABEL
        super().__init__(data_dir='', input_dim=input_dim, *args, **kwargs)

    def _setup_data_source(self):
        return len(self.dummy_targets)

    def _load_annotation(self, sample_id: int) -> dict:
        """Load one of two annotations:
            - Annotation 0 is made of: 3 targets of class_0, 0 of class_1 and 0 of class_2
            - Annotation 1 is made of: 2 targets of class_0, 1 of class_1 and 0 of class_2
        """
        return {"img_path": "", "target": self.dummy_targets[sample_id]}

    # DetectionDataset will call _load_image, but since there is no image on disk
    # we patch this method to return a random array of the expected shape.
    def _load_image(self, index: int) -> np.ndarray:
        return np.random.random(self.image_size)


class TestDetectionDatasetSubclassing(unittest.TestCase):
    def setUp(self) -> None:
        self.config_keep_empty_annotation = [
            {
                "class_inclusion_list": ["class_0", "class_1", "class_2"],
                "expected_n_targets_after_subclass": [3, 3]
            },
            {
                "class_inclusion_list": ["class_0"],
                "expected_n_targets_after_subclass": [3, 2]
            },
            {
                "class_inclusion_list": ["class_1"],
                "expected_n_targets_after_subclass": [0, 1]
            },
            {
                "class_inclusion_list": ["class_2"],
                "expected_n_targets_after_subclass": [0, 0]
            },
        ]
        self.config_ignore_empty_annotation = [
            {
                "class_inclusion_list": ["class_0", "class_1", "class_2"],
                "expected_n_targets_after_subclass": [3, 3]
            },
            {
                "class_inclusion_list": ["class_0"],
                "expected_n_targets_after_subclass": [3, 2]
            },
            {
                "class_inclusion_list": ["class_1"],
                "expected_n_targets_after_subclass": [1]
            }
        ]

    def test_subclass_keep_empty(self):
        """Check that subclassing only keeps annotations of the wanted classes."""
        for config in self.config_keep_empty_annotation:
            test_dataset = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=False,
                                                 class_inclusion_list=config["class_inclusion_list"])
            n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset)
            self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass)

    def test_subclass_drop_empty(self):
        """Check that empty annotations are not indexed (i.e. ignored) when ignore_empty_annotations=True."""
        for config in self.config_ignore_empty_annotation:
            test_dataset = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True,
                                                 class_inclusion_list=config["class_inclusion_list"])
            n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset)
            self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass)
        # The last case, class_2, should raise EmptyDatasetException because not a single
        # image has a target in the class_inclusion_list.
        with self.assertRaises(EmptyDatasetException):
            DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True,
                                  class_inclusion_list=["class_2"])

    def test_wrong_subclass(self):
        """Check that ValueError is raised when class_inclusion_list includes a class that does not exist."""
        with self.assertRaises(ValueError):
            DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["non_existing_class"])
        with self.assertRaises(ValueError):
            DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["class_0", "non_existing_class"])


def _count_targets_after_subclass_per_index(test_dataset: DummyDetectionDataset):
    """Iterate through every index of the dataset and count the associated number of targets per index."""
    dataset_target_len = []
    for index in range(len(test_dataset)):
        _img, targets = test_dataset[index]
        dataset_target_len.append(len(targets))
    return dataset_target_len


if __name__ == '__main__':
    unittest.main()
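The dummy fixture above also doubles as a template: a real subclass overrides the same three hooks. Below is a hypothetical sketch for a directory of per-image JSON annotation files; the file layout, label map, and the "img_path"/"boxes" keys are illustrative assumptions, and it relies on the base class's default _load_image reading the returned img_path:

import json
import os

import numpy as np

from super_gradients.training.datasets import DetectionDataset
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat


class JsonDetectionDataset(DetectionDataset):
    def __init__(self, data_dir, input_dim, *args, **kwargs):
        self._ann_dir = data_dir  # kept locally so _setup_data_source needs no base-class attribute
        kwargs['all_classes_list'] = ["car", "pedestrian"]  # assumed label map
        kwargs['original_target_format'] = DetectionTargetsFormat.XYXY_LABEL
        super().__init__(data_dir=data_dir, input_dim=input_dim, *args, **kwargs)

    def _setup_data_source(self):
        # One JSON file per sample; the return value is the number of samples.
        self.annotation_files = sorted(f for f in os.listdir(self._ann_dir) if f.endswith(".json"))
        return len(self.annotation_files)

    def _load_annotation(self, sample_id: int) -> dict:
        # Each file is assumed to hold an image path and [x1, y1, x2, y2, class_id] rows.
        with open(os.path.join(self._ann_dir, self.annotation_files[sample_id])) as f:
            record = json.load(f)
        return {"img_path": record["img_path"],
                "target": np.array(record["boxes"], dtype=np.float32)}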
import unittest

import numpy as np
import torch

from super_gradients.training.datasets import DetectionDataset
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat


class DummyDetectionDataset(DetectionDataset):
    def __init__(self, dataset_size, input_dim, *args, **kwargs):
        """Dummy dataset for testing subsampling."""
        self.dataset_size = dataset_size
        self.image_size = input_dim
        kwargs['all_classes_list'] = ["class_0", "class_1", "class_2"]
        kwargs['original_target_format'] = DetectionTargetsFormat.XYXY_LABEL
        super().__init__(data_dir='', input_dim=input_dim, *args, **kwargs)

    def _setup_data_source(self):
        return self.dataset_size

    def _load_annotation(self, sample_id: int) -> dict:
        """Load a dummy annotation."""
        return {"img_path": "", "target": torch.zeros(10, 6)}

    # DetectionDataset will call _load_image, but since there is no image on disk
    # we patch this method to return a random array of the expected shape.
    def _load_image(self, index: int) -> np.ndarray:
        return np.random.random(self.image_size)


class TestDetectionDatasetSubsampling(unittest.TestCase):
    def test_subsampling(self):
        """Check that subsampling caps the dataset length at max_num_samples."""
        for max_num_samples in [1, 1_000, 1_000_000]:
            test_dataset = DummyDetectionDataset(dataset_size=100_000, input_dim=(640, 512), max_num_samples=max_num_samples)
            self.assertEqual(len(test_dataset), min(max_num_samples, 100_000))


if __name__ == '__main__':
    unittest.main()