|
@@ -6,6 +6,10 @@ from super_gradients.training.dataloaders.dataloader_factory import (
|
|
|
classification_test_dataloader,
|
|
|
detection_test_dataloader,
|
|
|
segmentation_test_dataloader,
|
|
|
+ cifar10_val,
|
|
|
+ cifar10_train,
|
|
|
+ cifar100_val,
|
|
|
+ cifar100_train,
|
|
|
coco2017_train,
|
|
|
coco2017_val,
|
|
|
coco2017_train_ssd_lite_mobilenet_v2,
|
|
@@ -33,15 +37,23 @@ from super_gradients.training.dataloaders.dataloader_factory import (
|
|
|
pascal_voc_segmentation_train,
|
|
|
pascal_voc_segmentation_val,
|
|
|
supervisely_persons_train,
|
|
|
- supervisely_persons_val, pascal_voc_detection_train, pascal_voc_detection_val
|
|
|
+ supervisely_persons_val,
|
|
|
+ pascal_voc_detection_train,
|
|
|
+ pascal_voc_detection_val,
|
|
|
+)
|
|
|
+from super_gradients.training.datasets import (
|
|
|
+ COCODetectionDataset,
|
|
|
+ ImageNetDataset,
|
|
|
+ PascalAUG2012SegmentationDataSet,
|
|
|
+ PascalVOC2012SegmentationDataSet,
|
|
|
+ SuperviselyPersonsDataset,
|
|
|
+ PascalVOCDetectionDataset,
|
|
|
+ Cifar10,
|
|
|
+ Cifar100,
|
|
|
)
|
|
|
-from super_gradients.training.datasets import COCODetectionDataset, ImageNetDataset, PascalAUG2012SegmentationDataSet, \
|
|
|
- PascalVOC2012SegmentationDataSet, \
|
|
|
- SuperviselyPersonsDataset, PascalVOCDetectionDataset
|
|
|
|
|
|
|
|
|
class DataLoaderFactoryTest(unittest.TestCase):
|
|
|
-
|
|
|
def test_coco2017_train_creation(self):
|
|
|
dl_train = coco2017_train()
|
|
|
self.assertTrue(isinstance(dl_train, DataLoader))
|
|
@@ -124,7 +136,9 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
|
|
|
|
def test_imagenet_resnet50_kd_train_creation(self):
|
|
|
# Here we need to overwrite the sampler because the RepeatAugSampler used in KD is only supported for DDP
|
|
|
- dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
|
|
|
+ dl = imagenet_resnet50_kd_train(
|
|
|
+ dataloader_params={"sampler": {"InfiniteSampler": {}}}
|
|
|
+ )
|
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
|
@@ -153,6 +167,26 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
|
|
+ def test_cifar10_train_creation(self):
|
|
|
+ dl_train = cifar10_train()
|
|
|
+ self.assertTrue(isinstance(dl_train, DataLoader))
|
|
|
+ self.assertTrue(isinstance(dl_train.dataset, Cifar10))
|
|
|
+
|
|
|
+ def test_cifar10_val_creation(self):
|
|
|
+ dl_val = cifar10_val()
|
|
|
+ self.assertTrue(isinstance(dl_val, DataLoader))
|
|
|
+ self.assertTrue(isinstance(dl_val.dataset, Cifar10))
|
|
|
+
|
|
|
+ def test_cifar100_train_creation(self):
|
|
|
+ dl_train = cifar100_train()
|
|
|
+ self.assertTrue(isinstance(dl_train, DataLoader))
|
|
|
+ self.assertTrue(isinstance(dl_train.dataset, Cifar100))
|
|
|
+
|
|
|
+ def test_cifar100_val_creation(self):
|
|
|
+ dl_val = cifar100_val()
|
|
|
+ self.assertTrue(isinstance(dl_val, DataLoader))
|
|
|
+ self.assertTrue(isinstance(dl_val.dataset, Cifar100))
|
|
|
+
|
|
|
def test_classification_test_dataloader_creation(self):
|
|
|
dl = classification_test_dataloader()
|
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
@@ -209,5 +243,5 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
|
self.assertTrue(isinstance(dl.dataset, PascalVOCDetectionDataset))
|
|
|
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
|
+if __name__ == "__main__":
|
|
|
unittest.main()
|