|
@@ -1,6 +1,6 @@
|
|
import unittest
|
|
import unittest
|
|
|
|
|
|
-from torch.utils.data import DataLoader, TensorDataset
|
|
|
|
|
|
+from torch.utils.data import DataLoader, TensorDataset, RandomSampler
|
|
|
|
|
|
from super_gradients.common.registry.registry import register_dataset
|
|
from super_gradients.common.registry.registry import register_dataset
|
|
from super_gradients.training.dataloaders.dataloaders import (
|
|
from super_gradients.training.dataloaders.dataloaders import (
|
|
@@ -46,16 +46,18 @@ from super_gradients.training.dataloaders.dataloaders import (
|
|
from super_gradients.training.datasets import (
|
|
from super_gradients.training.datasets import (
|
|
COCODetectionDataset,
|
|
COCODetectionDataset,
|
|
ImageNetDataset,
|
|
ImageNetDataset,
|
|
- PascalAUG2012SegmentationDataSet,
|
|
|
|
PascalVOC2012SegmentationDataSet,
|
|
PascalVOC2012SegmentationDataSet,
|
|
SuperviselyPersonsDataset,
|
|
SuperviselyPersonsDataset,
|
|
PascalVOCDetectionDataset,
|
|
PascalVOCDetectionDataset,
|
|
Cifar10,
|
|
Cifar10,
|
|
Cifar100,
|
|
Cifar100,
|
|
|
|
+ PascalVOCAndAUGUnifiedDataset,
|
|
)
|
|
)
|
|
import torch
|
|
import torch
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
|
|
|
|
+from super_gradients.training.datasets.detection_datasets.pascal_voc_detection import PascalVOCUnifiedDetectionTrainDataset
|
|
|
|
+
|
|
|
|
|
|
@register_dataset("FixedLenDataset")
|
|
@register_dataset("FixedLenDataset")
|
|
class FixedLenDataset(TensorDataset):
|
|
class FixedLenDataset(TensorDataset):
|
|
@@ -70,6 +72,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl_train = coco2017_train()
|
|
dl_train = coco2017_train()
|
|
self.assertTrue(isinstance(dl_train, DataLoader))
|
|
self.assertTrue(isinstance(dl_train, DataLoader))
|
|
self.assertTrue(isinstance(dl_train.dataset, COCODetectionDataset))
|
|
self.assertTrue(isinstance(dl_train.dataset, COCODetectionDataset))
|
|
|
|
+ self.assertTrue(dl_train.batch_sampler.sampler._shuffle)
|
|
|
|
|
|
def test_coco2017_val_creation(self):
|
|
def test_coco2017_val_creation(self):
|
|
dl_val = coco2017_val()
|
|
dl_val = coco2017_val()
|
|
@@ -80,6 +83,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl_train = coco2017_train_ssd_lite_mobilenet_v2()
|
|
dl_train = coco2017_train_ssd_lite_mobilenet_v2()
|
|
self.assertTrue(isinstance(dl_train, DataLoader))
|
|
self.assertTrue(isinstance(dl_train, DataLoader))
|
|
self.assertTrue(isinstance(dl_train.dataset, COCODetectionDataset))
|
|
self.assertTrue(isinstance(dl_train.dataset, COCODetectionDataset))
|
|
|
|
+ self.assertTrue(dl_train.batch_sampler.sampler._shuffle)
|
|
|
|
|
|
def test_coco2017_val_ssdlite_mobilenet_creation(self):
|
|
def test_coco2017_val_ssdlite_mobilenet_creation(self):
|
|
dl_train = coco2017_val_ssd_lite_mobilenet_v2()
|
|
dl_train = coco2017_val_ssd_lite_mobilenet_v2()
|
|
@@ -90,6 +94,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl = imagenet_train()
|
|
dl = imagenet_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_imagenet_val_creation(self):
|
|
def test_imagenet_val_creation(self):
|
|
dl = imagenet_val()
|
|
dl = imagenet_val()
|
|
@@ -100,6 +105,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl = imagenet_efficientnet_train()
|
|
dl = imagenet_efficientnet_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_imagenet_efficientnet_val_creation(self):
|
|
def test_imagenet_efficientnet_val_creation(self):
|
|
dl = imagenet_efficientnet_val()
|
|
dl = imagenet_efficientnet_val()
|
|
@@ -110,6 +116,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl = imagenet_mobilenetv2_train()
|
|
dl = imagenet_mobilenetv2_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_imagenet_mobilenetv2_val_creation(self):
|
|
def test_imagenet_mobilenetv2_val_creation(self):
|
|
dl = imagenet_mobilenetv2_val()
|
|
dl = imagenet_mobilenetv2_val()
|
|
@@ -120,6 +127,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl = imagenet_mobilenetv3_train()
|
|
dl = imagenet_mobilenetv3_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_imagenet_mobilenetv3_val_creation(self):
|
|
def test_imagenet_mobilenetv3_val_creation(self):
|
|
dl = imagenet_mobilenetv3_val()
|
|
dl = imagenet_mobilenetv3_val()
|
|
@@ -130,6 +138,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl = imagenet_regnetY_train()
|
|
dl = imagenet_regnetY_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_imagenet_regnetY_val_creation(self):
|
|
def test_imagenet_regnetY_val_creation(self):
|
|
dl = imagenet_regnetY_val()
|
|
dl = imagenet_regnetY_val()
|
|
@@ -140,6 +149,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl = imagenet_resnet50_train()
|
|
dl = imagenet_resnet50_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_imagenet_resnet50_val_creation(self):
|
|
def test_imagenet_resnet50_val_creation(self):
|
|
dl = imagenet_resnet50_val()
|
|
dl = imagenet_resnet50_val()
|
|
@@ -151,6 +161,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
|
|
dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
+ self.assertTrue(dl.sampler._shuffle)
|
|
|
|
|
|
def test_imagenet_resnet50_kd_val_creation(self):
|
|
def test_imagenet_resnet50_kd_val_creation(self):
|
|
dl = imagenet_resnet50_kd_val()
|
|
dl = imagenet_resnet50_kd_val()
|
|
@@ -161,6 +172,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl = imagenet_vit_base_train()
|
|
dl = imagenet_vit_base_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_imagenet_vit_base_val_creation(self):
|
|
def test_imagenet_vit_base_val_creation(self):
|
|
dl = imagenet_vit_base_val()
|
|
dl = imagenet_vit_base_val()
|
|
@@ -181,6 +193,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl_train = cifar10_train()
|
|
dl_train = cifar10_train()
|
|
self.assertTrue(isinstance(dl_train, DataLoader))
|
|
self.assertTrue(isinstance(dl_train, DataLoader))
|
|
self.assertTrue(isinstance(dl_train.dataset, Cifar10))
|
|
self.assertTrue(isinstance(dl_train.dataset, Cifar10))
|
|
|
|
+ self.assertTrue(isinstance(dl_train.sampler, RandomSampler))
|
|
|
|
|
|
def test_cifar10_val_creation(self):
|
|
def test_cifar10_val_creation(self):
|
|
dl_val = cifar10_val()
|
|
dl_val = cifar10_val()
|
|
@@ -191,6 +204,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl_train = cifar100_train()
|
|
dl_train = cifar100_train()
|
|
self.assertTrue(isinstance(dl_train, DataLoader))
|
|
self.assertTrue(isinstance(dl_train, DataLoader))
|
|
self.assertTrue(isinstance(dl_train.dataset, Cifar100))
|
|
self.assertTrue(isinstance(dl_train.dataset, Cifar100))
|
|
|
|
+ self.assertTrue(isinstance(dl_train.sampler, RandomSampler))
|
|
|
|
|
|
def test_cifar100_val_creation(self):
|
|
def test_cifar100_val_creation(self):
|
|
dl_val = cifar100_val()
|
|
dl_val = cifar100_val()
|
|
@@ -215,17 +229,19 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
def test_pascal_aug_segmentation_train_creation(self):
|
|
def test_pascal_aug_segmentation_train_creation(self):
|
|
dl = pascal_aug_segmentation_train()
|
|
dl = pascal_aug_segmentation_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
- self.assertTrue(isinstance(dl.dataset, PascalAUG2012SegmentationDataSet))
|
|
|
|
|
|
+ self.assertTrue(isinstance(dl.dataset, PascalVOCAndAUGUnifiedDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_pascal_aug_segmentation_val_creation(self):
|
|
def test_pascal_aug_segmentation_val_creation(self):
|
|
dl = pascal_aug_segmentation_val()
|
|
dl = pascal_aug_segmentation_val()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
- self.assertTrue(isinstance(dl.dataset, PascalAUG2012SegmentationDataSet))
|
|
|
|
|
|
+ self.assertTrue(isinstance(dl.dataset, PascalVOC2012SegmentationDataSet))
|
|
|
|
|
|
def test_pascal_voc_segmentation_train_creation(self):
|
|
def test_pascal_voc_segmentation_train_creation(self):
|
|
dl = pascal_voc_segmentation_train()
|
|
dl = pascal_voc_segmentation_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, PascalVOC2012SegmentationDataSet))
|
|
self.assertTrue(isinstance(dl.dataset, PascalVOC2012SegmentationDataSet))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_pascal_voc_segmentation_val_creation(self):
|
|
def test_pascal_voc_segmentation_val_creation(self):
|
|
dl = pascal_voc_segmentation_val()
|
|
dl = pascal_voc_segmentation_val()
|
|
@@ -236,6 +252,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
dl = supervisely_persons_train()
|
|
dl = supervisely_persons_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl.dataset, SuperviselyPersonsDataset))
|
|
self.assertTrue(isinstance(dl.dataset, SuperviselyPersonsDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_supervisely_persons_val_dataloader_creation(self):
|
|
def test_supervisely_persons_val_dataloader_creation(self):
|
|
dl = supervisely_persons_val()
|
|
dl = supervisely_persons_val()
|
|
@@ -245,7 +262,8 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
def test_pascal_voc_train_creation(self):
|
|
def test_pascal_voc_train_creation(self):
|
|
dl = pascal_voc_detection_train()
|
|
dl = pascal_voc_detection_train()
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
- self.assertTrue(isinstance(dl.dataset, PascalVOCDetectionDataset))
|
|
|
|
|
|
+ self.assertTrue(isinstance(dl.dataset, PascalVOCUnifiedDetectionTrainDataset))
|
|
|
|
+ self.assertTrue(dl.batch_sampler.sampler._shuffle)
|
|
|
|
|
|
def test_pascal_voc_val_creation(self):
|
|
def test_pascal_voc_val_creation(self):
|
|
dl = pascal_voc_detection_val()
|
|
dl = pascal_voc_detection_val()
|
|
@@ -254,12 +272,14 @@ class DataLoaderFactoryTest(unittest.TestCase):
|
|
|
|
|
|
def test_get_with_external_dataset_creation(self):
|
|
def test_get_with_external_dataset_creation(self):
|
|
dataset = Cifar10(root="./data/cifar10", train=False, download=True)
|
|
dataset = Cifar10(root="./data/cifar10", train=False, download=True)
|
|
- dl = get(dataset=dataset, dataloader_params={"batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True})
|
|
|
|
|
|
+ dl = get(dataset=dataset, dataloader_params={"batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True, "shuffle": True})
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
self.assertTrue(isinstance(dl, DataLoader))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
def test_get_with_registered_dataset(self):
|
|
def test_get_with_registered_dataset(self):
|
|
- dl = get(dataloader_params={"dataset": "FixedLenDataset", "batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True})
|
|
|
|
|
|
+ dl = get(dataloader_params={"dataset": "FixedLenDataset", "batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True, "shuffle": True})
|
|
self.assertTrue(isinstance(dl.dataset, FixedLenDataset))
|
|
self.assertTrue(isinstance(dl.dataset, FixedLenDataset))
|
|
|
|
+ self.assertTrue(isinstance(dl.sampler, RandomSampler))
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|