#345 Feature/sg 251 refactor cifar10 dataset interface

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-251_refactor_cifar10_dataset_interface
cifar100_dataset_params.yaml

batch_size: 256 # batch size for trainset
val_batch_size: 512 # batch size for valset in DatasetInterface

# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE

train_dataset_params:
  root: /data/cifar100
  train: True
  transforms:
    - RandomCrop:
        size: 32
        padding: 4
    - RandomHorizontalFlip
    - ToTensor
    - Normalize:
        mean:
          - 0.5071
          - 0.4865
          - 0.4409
        std:
          - 0.2673
          - 0.2564
          - 0.2762
  target_transform: null
  download: True

train_dataloader_params:
  batch_size: 256
  num_workers: 8
  drop_last: False
  pin_memory: True

val_dataset_params:
  root: /data/cifar100
  train: False
  transforms:
    - ToTensor
    - Normalize:
        mean:
          - 0.5071
          - 0.4865
          - 0.4409
        std:
          - 0.2673
          - 0.2564
          - 0.2762
  target_transform: null
  download: True

val_dataloader_params:
  batch_size: 512
  num_workers: 8
  drop_last: False
  pin_memory: True
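
A minimal usage sketch (not part of the diff): the cifar100_train factory added in this PR loads this recipe by config_name="cifar100_dataset_params", and, following the override pattern used elsewhere in dataloader_factory (e.g. the sampler override in the KD dataloader test below), explicitly passed params are assumed to take precedence over these YAML defaults.

    from super_gradients.training.dataloaders.dataloader_factory import cifar100_train

    # Keys omitted here are assumed to fall back to the recipe defaults above.
    dl = cifar100_train(
        dataset_params={"root": "/data/cifar100", "download": True},
        dataloader_params={"batch_size": 128, "num_workers": 4},
    )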
cifar10_dataset_params.yaml
@@ -1,2 +1,54 @@
 batch_size: 256 # batch size for trainset
-val_batch_size: 512 # batch size for valset in DatasetInterface
+val_batch_size: 512 # batch size for valset in DatasetInterface
+
+# TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
+
+train_dataset_params:
+  root: /data/cifar10
+  train: True
+  transforms:
+    - RandomCrop:
+        size: 32
+        padding: 4
+    - RandomHorizontalFlip
+    - ToTensor
+    - Normalize:
+        mean:
+          - 0.4914
+          - 0.4822
+          - 0.4465
+        std:
+          - 0.2023
+          - 0.1994
+          - 0.2010
+  target_transform: null
+  download: True
+
+train_dataloader_params:
+  batch_size: 256
+  num_workers: 8
+  drop_last: False
+  pin_memory: True
+
+val_dataset_params:
+  root: /data/cifar10
+  train: False
+  transforms:
+    - ToTensor
+    - Normalize:
+        mean:
+          - 0.4914
+          - 0.4822
+          - 0.4465
+        std:
+          - 0.2023
+          - 0.1994
+          - 0.2010
+  target_transform: null
+  download: True
+
+val_dataloader_params:
+  batch_size: 512
+  num_workers: 8
+  drop_last: False
+  pin_memory: True
super_gradients/training/dataloaders/dataloader_factory.py
@@ -17,6 +17,7 @@ from super_gradients.training.datasets.detection_datasets.pascal_voc_detection i
 from super_gradients.training.utils import get_param
 from super_gradients.training.datasets import ImageNetDataset
 from super_gradients.training.datasets.detection_datasets import COCODetectionDataset
+from super_gradients.training.datasets.classification_datasets.cifar import Cifar10, Cifar100
 from super_gradients.training.datasets.segmentation_datasets import CityscapesDataset, CoCoSegmentationDataSet, \
     PascalAUG2012SegmentationDataSet, \
     PascalVOC2012SegmentationDataSet, SuperviselyPersonsDataset
@@ -237,6 +238,42 @@ def tiny_imagenet_val(dataset_params={}, dataloader_params={}, config_name="tiny
                            dataloader_params=dataloader_params)
 
 
+def cifar10_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+    return get_data_loader(config_name="cifar10_dataset_params",
+                           dataset_cls=Cifar10,
+                           train=True,
+                           dataset_params=dataset_params,
+                           dataloader_params=dataloader_params
+                           )
+
+
+def cifar10_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+    return get_data_loader(config_name="cifar10_dataset_params",
+                           dataset_cls=Cifar10,
+                           train=False,
+                           dataset_params=dataset_params,
+                           dataloader_params=dataloader_params
+                           )
+
+
+def cifar100_train(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+    return get_data_loader(config_name="cifar100_dataset_params",
+                           dataset_cls=Cifar100,
+                           train=True,
+                           dataset_params=dataset_params,
+                           dataloader_params=dataloader_params
+                           )
+
+
+def cifar100_val(dataset_params: Dict = {}, dataloader_params: Dict = {}):
+    return get_data_loader(config_name="cifar100_dataset_params",
+                           dataset_cls=Cifar100,
+                           train=False,
+                           dataset_params=dataset_params,
+                           dataloader_params=dataloader_params
+                           )
+
+
 def classification_test_dataloader(batch_size: int = 5, image_size: int = 32) -> DataLoader:
     images = torch.Tensor(np.zeros((batch_size, 3, image_size, image_size)))
     ground_truth = torch.LongTensor(np.zeros((batch_size)))
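
For reference, a minimal sketch of how these new factories plug into Trainer, mirroring the end-to-end test added in this PR (model name, loss, and metrics are taken verbatim from that test; the experiment name is hypothetical):

    import super_gradients
    from super_gradients import Trainer
    from super_gradients.training import models
    from super_gradients.training.dataloaders.dataloader_factory import cifar10_train, cifar10_val

    super_gradients.init_trainer()
    trainer = Trainer("cifar10_example", model_checkpoints_location="local")  # hypothetical experiment name
    model = models.get("resnet18_cifar", arch_params={"num_classes": 10})

    # The factories return ready-to-use torch DataLoaders built from the recipe defaults.
    trainer.train(
        model=model,
        training_params={
            "max_epochs": 1,
            "initial_lr": 0.1,
            "loss": "cross_entropy",
            "train_metrics_list": ["Accuracy"],
            "valid_metrics_list": ["Accuracy"],
            "metric_to_watch": "Accuracy",
        },
        train_loader=cifar10_train(),
        valid_loader=cifar10_val(),
    )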
super_gradients/training/datasets/__init__.py
@@ -5,7 +5,7 @@ from super_gradients.training.datasets.data_augmentation import DataAugmentation
 from super_gradients.training.datasets.sg_dataset import ListDataset, DirectoryDataSet
 from super_gradients.training.datasets.all_datasets import CLASSIFICATION_DATASETS, OBJECT_DETECTION_DATASETS, \
     SEMANTIC_SEGMENTATION_DATASETS
-from super_gradients.training.datasets.classification_datasets import ImageNetDataset
+from super_gradients.training.datasets.classification_datasets import ImageNetDataset, Cifar10, Cifar100
 from super_gradients.training.datasets.detection_datasets import DetectionDataset, COCODetectionDataset, PascalVOCDetectionDataset
 from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
 from super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation import PascalVOC2012SegmentationDataSet
@@ -31,4 +31,5 @@ __all__ = ['DataAugmentation', 'ListDataset', 'DirectoryDataSet', 'CLASSIFICATIO
            'TestYoloDetectionDatasetInterface', 'DetectionTestDatasetInterface', 'ClassificationTestDatasetInterface',
            'SegmentationTestDatasetInterface',
            'ImageNetDatasetInterface',
-           'DetectionDataset', 'COCODetectionDataset', 'PascalVOCDetectionDataset', 'ImageNetDataset', 'SuperviselyPersonsDataset']
+           'DetectionDataset', 'COCODetectionDataset', 'PascalVOCDetectionDataset', 'ImageNetDataset',
+           'Cifar10', 'Cifar100', 'SuperviselyPersonsDataset']
super_gradients/training/datasets/classification_datasets/__init__.py
@@ -1,4 +1,5 @@
 from super_gradients.training.datasets.classification_datasets.imagenet_dataset import ImageNetDataset
+from super_gradients.training.datasets.classification_datasets.cifar import Cifar10, Cifar100
 
 
-__all__ = ['ImageNetDataset']
+__all__ = ['ImageNetDataset', 'Cifar10', 'Cifar100']
super_gradients/training/datasets/classification_datasets/cifar.py

from typing import Optional, Callable

from torchvision.transforms import Compose
from torchvision.datasets import CIFAR10, CIFAR100

from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.common.decorators.factory_decorator import resolve_param


class Cifar10(CIFAR10):
    """
    CIFAR10 Dataset

    :param root:             Path for the data to be extracted
    :param train:            Bool to load training (True) or validation (False) part of the dataset
    :param transforms:       List of transforms to apply sequentially on sample. Wrapped internally with torchvision.Compose
    :param target_transform: Transform to apply to target output
    :param download:         Download (True) the dataset from source
    """

    @resolve_param("transforms", ListFactory(TransformsFactory()))
    def __init__(
        self,
        root: str,
        train: bool = True,
        transforms: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super(Cifar10, self).__init__(
            root=root,
            train=train,
            transform=Compose(transforms),
            target_transform=target_transform,
            download=download,
        )


class Cifar100(CIFAR100):
    @resolve_param("transforms", ListFactory(TransformsFactory()))
    def __init__(
        self,
        root: str,
        train: bool = True,
        transforms: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """
        CIFAR100 Dataset

        :param root:             Path for the data to be extracted
        :param train:            Bool to load training (True) or validation (False) part of the dataset
        :param transforms:       List of transforms to apply sequentially on sample. Wrapped internally with torchvision.Compose
        :param target_transform: Transform to apply to target output
        :param download:         Download (True) the dataset from source
        """
        super(Cifar100, self).__init__(
            root=root,
            train=train,
            transform=Compose(transforms),
            target_transform=target_transform,
            download=download,
        )
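
As a usage note, the @resolve_param decorator lets transforms arrive either as instantiated torchvision transforms or as the dict/string recipe form shown in the YAML above; a minimal sketch of direct instantiation, assuming a local data path:

    from torchvision.transforms import RandomCrop, RandomHorizontalFlip, ToTensor

    # Recipe-style input also works, e.g.
    # [{"RandomCrop": {"size": 32, "padding": 4}}, "RandomHorizontalFlip", "ToTensor"],
    # which TransformsFactory resolves before Compose wraps the list.
    ds = Cifar10(
        root="/data/cifar10",  # assumed local path, matching the recipe defaults
        train=True,
        transforms=[RandomCrop(32, padding=4), RandomHorizontalFlip(), ToTensor()],
        download=True,
    )
    image, label = ds[0]  # torchvision CIFAR10 returns an (image, target) pair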
tests/end_to_end_tests/__init__.py
@@ -2,6 +2,6 @@
 
 from tests.end_to_end_tests.trainer_test import TestTrainer
 
-from tests.end_to_end_tests.cifar10_trainer_test import TestCifar10Trainer
+from tests.end_to_end_tests.cifar_trainer_test import TestCifarTrainer
 
-__all__ = ['TestTrainer', 'TestCifar10Trainer']
+__all__ = ['TestTrainer', 'TestCifarTrainer']
tests/end_to_end_tests/cifar10_trainer_test.py (removed; replaced by cifar_trainer_test.py)

import unittest

from super_gradients.training import models
import super_gradients
from super_gradients import Trainer
from super_gradients.training.datasets.dataset_interfaces import LibraryDatasetInterface


class TestCifar10Trainer(unittest.TestCase):
    def test_train_cifar10(self):
        super_gradients.init_trainer()
        trainer = Trainer("test", model_checkpoints_location='local')
        cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar10")
        trainer.connect_dataset_interface(cifar_10_dataset_interface)
        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
        trainer.train(model=model, training_params={"max_epochs": 1})


if __name__ == '__main__':
    unittest.main()
tests/end_to_end_tests/cifar_trainer_test.py

import unittest

from super_gradients.training import models
import super_gradients
from super_gradients import Trainer
from super_gradients.training.datasets.dataset_interfaces import LibraryDatasetInterface
from super_gradients.training.dataloaders.dataloader_factory import (
    cifar10_train,
    cifar10_val,
    cifar100_train,
    cifar100_val,
)


class TestCifarTrainer(unittest.TestCase):
    def test_train_cifar10(self):
        super_gradients.init_trainer()
        trainer = Trainer("test", model_checkpoints_location="local")
        cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar10")
        trainer.connect_dataset_interface(cifar_10_dataset_interface)
        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
        trainer.train(
            model=model,
            training_params={
                "max_epochs": 1,
                "initial_lr": 0.1,
                "loss": "cross_entropy",
                "train_metrics_list": ["Accuracy"],
                "valid_metrics_list": ["Accuracy"],
                "metric_to_watch": "Accuracy",
            },
        )

    def test_train_cifar10_dataloader(self):
        super_gradients.init_trainer()
        trainer = Trainer("test", model_checkpoints_location="local")
        cifar10_train_dl, cifar10_val_dl = cifar10_train(), cifar10_val()
        model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
        trainer.train(
            model=model,
            training_params={
                "max_epochs": 1,
                "initial_lr": 0.1,
                "loss": "cross_entropy",
                "train_metrics_list": ["Accuracy"],
                "valid_metrics_list": ["Accuracy"],
                "metric_to_watch": "Accuracy",
            },
            train_loader=cifar10_train_dl,
            valid_loader=cifar10_val_dl,
        )

    def test_train_cifar100(self):
        super_gradients.init_trainer()
        trainer = Trainer("test", model_checkpoints_location="local")
        cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar100")
        trainer.connect_dataset_interface(cifar_10_dataset_interface)
        model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
        trainer.train(
            model=model,
            training_params={
                "max_epochs": 1,
                "initial_lr": 0.1,
                "loss": "cross_entropy",
                "train_metrics_list": ["Accuracy"],
                "valid_metrics_list": ["Accuracy"],
                "metric_to_watch": "Accuracy",
            },
        )

    def test_train_cifar100_dataloader(self):
        super_gradients.init_trainer()
        trainer = Trainer("test", model_checkpoints_location="local")
        cifar100_train_dl, cifar100_val_dl = cifar100_train(), cifar100_val()
        model = models.get("resnet18_cifar", arch_params={"num_classes": 100})
        trainer.train(
            model=model,
            training_params={
                "max_epochs": 1,
                "initial_lr": 0.1,
                "loss": "cross_entropy",
                "train_metrics_list": ["Accuracy"],
                "valid_metrics_list": ["Accuracy"],
                "metric_to_watch": "Accuracy",
            },
            train_loader=cifar100_train_dl,
            valid_loader=cifar100_val_dl,
        )


if __name__ == "__main__":
    unittest.main()
@@ -6,6 +6,10 @@ from super_gradients.training.dataloaders.dataloader_factory import (
     classification_test_dataloader,
     detection_test_dataloader,
     segmentation_test_dataloader,
+    cifar10_val,
+    cifar10_train,
+    cifar100_val,
+    cifar100_train,
     coco2017_train,
     coco2017_val,
     coco2017_train_ssd_lite_mobilenet_v2,
@@ -33,15 +37,23 @@ from super_gradients.training.dataloaders.dataloader_factory import (
     pascal_voc_segmentation_train,
     pascal_voc_segmentation_val,
     supervisely_persons_train,
-    supervisely_persons_val, pascal_voc_detection_train, pascal_voc_detection_val
+    supervisely_persons_val,
+    pascal_voc_detection_train,
+    pascal_voc_detection_val,
+)
+from super_gradients.training.datasets import (
+    COCODetectionDataset,
+    ImageNetDataset,
+    PascalAUG2012SegmentationDataSet,
+    PascalVOC2012SegmentationDataSet,
+    SuperviselyPersonsDataset,
+    PascalVOCDetectionDataset,
+    Cifar10,
+    Cifar100,
 )
-from super_gradients.training.datasets import COCODetectionDataset, ImageNetDataset, PascalAUG2012SegmentationDataSet, \
-    PascalVOC2012SegmentationDataSet, \
-    SuperviselyPersonsDataset, PascalVOCDetectionDataset
 
 
 class DataLoaderFactoryTest(unittest.TestCase):
-
     def test_coco2017_train_creation(self):
         dl_train = coco2017_train()
         self.assertTrue(isinstance(dl_train, DataLoader))
@@ -124,7 +136,9 @@ class DataLoaderFactoryTest(unittest.TestCase):
 
     def test_imagenet_resnet50_kd_train_creation(self):
         # Here we need to overwrite the sampler because the RepeatAugSampler used in KD is only supported for DDP
-        dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
+        dl = imagenet_resnet50_kd_train(
+            dataloader_params={"sampler": {"InfiniteSampler": {}}}
+        )
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
 
@@ -153,6 +167,26 @@ class DataLoaderFactoryTest(unittest.TestCase):
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
 
+    def test_cifar10_train_creation(self):
+        dl_train = cifar10_train()
+        self.assertTrue(isinstance(dl_train, DataLoader))
+        self.assertTrue(isinstance(dl_train.dataset, Cifar10))
+
+    def test_cifar10_val_creation(self):
+        dl_val = cifar10_val()
+        self.assertTrue(isinstance(dl_val, DataLoader))
+        self.assertTrue(isinstance(dl_val.dataset, Cifar10))
+
+    def test_cifar100_train_creation(self):
+        dl_train = cifar100_train()
+        self.assertTrue(isinstance(dl_train, DataLoader))
+        self.assertTrue(isinstance(dl_train.dataset, Cifar100))
+
+    def test_cifar100_val_creation(self):
+        dl_val = cifar100_val()
+        self.assertTrue(isinstance(dl_val, DataLoader))
+        self.assertTrue(isinstance(dl_val.dataset, Cifar100))
+
     def test_classification_test_dataloader_creation(self):
         dl = classification_test_dataloader()
         self.assertTrue(isinstance(dl, DataLoader))
@@ -209,5 +243,5 @@ class DataLoaderFactoryTest(unittest.TestCase):
         self.assertTrue(isinstance(dl.dataset, PascalVOCDetectionDataset))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()