Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#581 Bug/sg 512 shuffle bugfix in recipe dataloaders

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bug/SG-512_shuffle_bugfix_in_recipe_datalaoders
1 changed file with 27 additions and 7 deletions
  1. 27
    7
      tests/unit_tests/datalaoder_factory_test.py
@@ -1,6 +1,6 @@
 import unittest
 import unittest
 
 
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import DataLoader, TensorDataset, RandomSampler
 
 
 from super_gradients.common.registry.registry import register_dataset
 from super_gradients.common.registry.registry import register_dataset
 from super_gradients.training.dataloaders.dataloaders import (
 from super_gradients.training.dataloaders.dataloaders import (
@@ -46,16 +46,18 @@ from super_gradients.training.dataloaders.dataloaders import (
 from super_gradients.training.datasets import (
 from super_gradients.training.datasets import (
     COCODetectionDataset,
     COCODetectionDataset,
     ImageNetDataset,
     ImageNetDataset,
-    PascalAUG2012SegmentationDataSet,
     PascalVOC2012SegmentationDataSet,
     PascalVOC2012SegmentationDataSet,
     SuperviselyPersonsDataset,
     SuperviselyPersonsDataset,
     PascalVOCDetectionDataset,
     PascalVOCDetectionDataset,
     Cifar10,
     Cifar10,
     Cifar100,
     Cifar100,
+    PascalVOCAndAUGUnifiedDataset,
 )
 )
 import torch
 import torch
 import numpy as np
 import numpy as np
 
 
+from super_gradients.training.datasets.detection_datasets.pascal_voc_detection import PascalVOCUnifiedDetectionTrainDataset
+
 
 
 @register_dataset("FixedLenDataset")
 @register_dataset("FixedLenDataset")
 class FixedLenDataset(TensorDataset):
 class FixedLenDataset(TensorDataset):
@@ -70,6 +72,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl_train = coco2017_train()
         dl_train = coco2017_train()
         self.assertTrue(isinstance(dl_train, DataLoader))
         self.assertTrue(isinstance(dl_train, DataLoader))
         self.assertTrue(isinstance(dl_train.dataset, COCODetectionDataset))
         self.assertTrue(isinstance(dl_train.dataset, COCODetectionDataset))
+        self.assertTrue(dl_train.batch_sampler.sampler._shuffle)
 
 
     def test_coco2017_val_creation(self):
     def test_coco2017_val_creation(self):
         dl_val = coco2017_val()
         dl_val = coco2017_val()
@@ -80,6 +83,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl_train = coco2017_train_ssd_lite_mobilenet_v2()
         dl_train = coco2017_train_ssd_lite_mobilenet_v2()
         self.assertTrue(isinstance(dl_train, DataLoader))
         self.assertTrue(isinstance(dl_train, DataLoader))
         self.assertTrue(isinstance(dl_train.dataset, COCODetectionDataset))
         self.assertTrue(isinstance(dl_train.dataset, COCODetectionDataset))
+        self.assertTrue(dl_train.batch_sampler.sampler._shuffle)
 
 
     def test_coco2017_val_ssdlite_mobilenet_creation(self):
     def test_coco2017_val_ssdlite_mobilenet_creation(self):
         dl_train = coco2017_val_ssd_lite_mobilenet_v2()
         dl_train = coco2017_val_ssd_lite_mobilenet_v2()
@@ -90,6 +94,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl = imagenet_train()
         dl = imagenet_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_imagenet_val_creation(self):
     def test_imagenet_val_creation(self):
         dl = imagenet_val()
         dl = imagenet_val()
@@ -100,6 +105,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl = imagenet_efficientnet_train()
         dl = imagenet_efficientnet_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_imagenet_efficientnet_val_creation(self):
     def test_imagenet_efficientnet_val_creation(self):
         dl = imagenet_efficientnet_val()
         dl = imagenet_efficientnet_val()
@@ -110,6 +116,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl = imagenet_mobilenetv2_train()
         dl = imagenet_mobilenetv2_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_imagenet_mobilenetv2_val_creation(self):
     def test_imagenet_mobilenetv2_val_creation(self):
         dl = imagenet_mobilenetv2_val()
         dl = imagenet_mobilenetv2_val()
@@ -120,6 +127,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl = imagenet_mobilenetv3_train()
         dl = imagenet_mobilenetv3_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_imagenet_mobilenetv3_val_creation(self):
     def test_imagenet_mobilenetv3_val_creation(self):
         dl = imagenet_mobilenetv3_val()
         dl = imagenet_mobilenetv3_val()
@@ -130,6 +138,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl = imagenet_regnetY_train()
         dl = imagenet_regnetY_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_imagenet_regnetY_val_creation(self):
     def test_imagenet_regnetY_val_creation(self):
         dl = imagenet_regnetY_val()
         dl = imagenet_regnetY_val()
@@ -140,6 +149,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl = imagenet_resnet50_train()
         dl = imagenet_resnet50_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_imagenet_resnet50_val_creation(self):
     def test_imagenet_resnet50_val_creation(self):
         dl = imagenet_resnet50_val()
         dl = imagenet_resnet50_val()
@@ -151,6 +161,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
         dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
+        self.assertTrue(dl.sampler._shuffle)
 
 
     def test_imagenet_resnet50_kd_val_creation(self):
     def test_imagenet_resnet50_kd_val_creation(self):
         dl = imagenet_resnet50_kd_val()
         dl = imagenet_resnet50_kd_val()
@@ -161,6 +172,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl = imagenet_vit_base_train()
         dl = imagenet_vit_base_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_imagenet_vit_base_val_creation(self):
     def test_imagenet_vit_base_val_creation(self):
         dl = imagenet_vit_base_val()
         dl = imagenet_vit_base_val()
@@ -181,6 +193,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl_train = cifar10_train()
         dl_train = cifar10_train()
         self.assertTrue(isinstance(dl_train, DataLoader))
         self.assertTrue(isinstance(dl_train, DataLoader))
         self.assertTrue(isinstance(dl_train.dataset, Cifar10))
         self.assertTrue(isinstance(dl_train.dataset, Cifar10))
+        self.assertTrue(isinstance(dl_train.sampler, RandomSampler))
 
 
     def test_cifar10_val_creation(self):
     def test_cifar10_val_creation(self):
         dl_val = cifar10_val()
         dl_val = cifar10_val()
@@ -191,6 +204,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl_train = cifar100_train()
         dl_train = cifar100_train()
         self.assertTrue(isinstance(dl_train, DataLoader))
         self.assertTrue(isinstance(dl_train, DataLoader))
         self.assertTrue(isinstance(dl_train.dataset, Cifar100))
         self.assertTrue(isinstance(dl_train.dataset, Cifar100))
+        self.assertTrue(isinstance(dl_train.sampler, RandomSampler))
 
 
     def test_cifar100_val_creation(self):
     def test_cifar100_val_creation(self):
         dl_val = cifar100_val()
         dl_val = cifar100_val()
@@ -215,17 +229,19 @@ class DataLoaderFactoryTest(unittest.TestCase):
     def test_pascal_aug_segmentation_train_creation(self):
     def test_pascal_aug_segmentation_train_creation(self):
         dl = pascal_aug_segmentation_train()
         dl = pascal_aug_segmentation_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
-        self.assertTrue(isinstance(dl.dataset, PascalAUG2012SegmentationDataSet))
+        self.assertTrue(isinstance(dl.dataset, PascalVOCAndAUGUnifiedDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_pascal_aug_segmentation_val_creation(self):
     def test_pascal_aug_segmentation_val_creation(self):
         dl = pascal_aug_segmentation_val()
         dl = pascal_aug_segmentation_val()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
-        self.assertTrue(isinstance(dl.dataset, PascalAUG2012SegmentationDataSet))
+        self.assertTrue(isinstance(dl.dataset, PascalVOC2012SegmentationDataSet))
 
 
     def test_pascal_voc_segmentation_train_creation(self):
     def test_pascal_voc_segmentation_train_creation(self):
         dl = pascal_voc_segmentation_train()
         dl = pascal_voc_segmentation_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, PascalVOC2012SegmentationDataSet))
         self.assertTrue(isinstance(dl.dataset, PascalVOC2012SegmentationDataSet))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_pascal_voc_segmentation_val_creation(self):
     def test_pascal_voc_segmentation_val_creation(self):
         dl = pascal_voc_segmentation_val()
         dl = pascal_voc_segmentation_val()
@@ -236,6 +252,7 @@ class DataLoaderFactoryTest(unittest.TestCase):
         dl = supervisely_persons_train()
         dl = supervisely_persons_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, SuperviselyPersonsDataset))
         self.assertTrue(isinstance(dl.dataset, SuperviselyPersonsDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_supervisely_persons_val_dataloader_creation(self):
     def test_supervisely_persons_val_dataloader_creation(self):
         dl = supervisely_persons_val()
         dl = supervisely_persons_val()
@@ -245,7 +262,8 @@ class DataLoaderFactoryTest(unittest.TestCase):
     def test_pascal_voc_train_creation(self):
     def test_pascal_voc_train_creation(self):
         dl = pascal_voc_detection_train()
         dl = pascal_voc_detection_train()
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
-        self.assertTrue(isinstance(dl.dataset, PascalVOCDetectionDataset))
+        self.assertTrue(isinstance(dl.dataset, PascalVOCUnifiedDetectionTrainDataset))
+        self.assertTrue(dl.batch_sampler.sampler._shuffle)
 
 
     def test_pascal_voc_val_creation(self):
     def test_pascal_voc_val_creation(self):
         dl = pascal_voc_detection_val()
         dl = pascal_voc_detection_val()
@@ -254,12 +272,14 @@ class DataLoaderFactoryTest(unittest.TestCase):
 
 
     def test_get_with_external_dataset_creation(self):
     def test_get_with_external_dataset_creation(self):
         dataset = Cifar10(root="./data/cifar10", train=False, download=True)
         dataset = Cifar10(root="./data/cifar10", train=False, download=True)
-        dl = get(dataset=dataset, dataloader_params={"batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True})
+        dl = get(dataset=dataset, dataloader_params={"batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True, "shuffle": True})
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl, DataLoader))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
     def test_get_with_registered_dataset(self):
     def test_get_with_registered_dataset(self):
-        dl = get(dataloader_params={"dataset": "FixedLenDataset", "batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True})
+        dl = get(dataloader_params={"dataset": "FixedLenDataset", "batch_size": 256, "num_workers": 8, "drop_last": False, "pin_memory": True, "shuffle": True})
         self.assertTrue(isinstance(dl.dataset, FixedLenDataset))
         self.assertTrue(isinstance(dl.dataset, FixedLenDataset))
+        self.assertTrue(isinstance(dl.sampler, RandomSampler))
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
Discard