Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#495 Feature/sg 416 albumentations plugin for classification

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-416_albumentations_plugin_for_classification
@@ -4,20 +4,36 @@ from omegaconf import ListConfig
 
 
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.base_factory import BaseFactory
 from super_gradients.common.factories.list_factory import ListFactory
 from super_gradients.common.factories.list_factory import ListFactory
-from super_gradients.training.transforms import TRANSFORMS
+from super_gradients.training.transforms import TRANSFORMS, ALBUMENTATIONS_TRANSFORMS, ALBUMENTATIONS_COMP_TRANSFORMS, imported_albumentations_failure
+from super_gradients.training.transforms.pipeline_adaptors import AlbumentationsAdaptor
 
 
 
 
 class TransformsFactory(BaseFactory):
 class TransformsFactory(BaseFactory):
-
     def __init__(self):
     def __init__(self):
         super().__init__(TRANSFORMS)
         super().__init__(TRANSFORMS)
 
 
     def get(self, conf: Union[str, dict]):
     def get(self, conf: Union[str, dict]):
 
 
-        # SPECIAL HANDLING FOR COMPOSE
-        if isinstance(conf, Mapping) and 'Compose' in conf:
-            conf['Compose']['transforms'] = ListFactory(TransformsFactory()).get(conf['Compose']['transforms'])
+        # SPECIAL HANDLING FOR COMPOSE AND ALBUMENTATIONS
+        if isinstance(conf, Mapping) and "Albumentations" in conf:
+            return AlbumentationsAdaptor(AlbumentationsTransformsFactory().get(conf["Albumentations"]))
+        if isinstance(conf, Mapping) and "Compose" in conf:
+            conf["Compose"]["transforms"] = ListFactory(TransformsFactory()).get(conf["Compose"]["transforms"])
         elif isinstance(conf, (list, ListConfig)):
         elif isinstance(conf, (list, ListConfig)):
             conf = ListFactory(TransformsFactory()).get(conf)
             conf = ListFactory(TransformsFactory()).get(conf)
 
 
         return super().get(conf)
         return super().get(conf)
+
+
class AlbumentationsTransformsFactory(BaseFactory):
    """Factory resolving configuration entries into albumentations transform objects."""

    def __init__(self):
        # Albumentations is an optional dependency: re-raise the recorded import
        # failure only when a user actually requests an albumentations transform.
        if imported_albumentations_failure:
            raise imported_albumentations_failure
        super().__init__(ALBUMENTATIONS_TRANSFORMS)

    def get(self, conf: Union[str, dict]):
        """Instantiate an albumentations transform from a name or {name: params} mapping.

        Composition transforms (Compose, OneOf, ...) get their nested "transforms"
        list resolved recursively before the composition itself is instantiated.
        """
        if isinstance(conf, Mapping):
            _type = next(iter(conf))  # THE TYPE NAME (single-key mapping)
            if _type in ALBUMENTATIONS_COMP_TRANSFORMS:
                conf[_type]["transforms"] = ListFactory(AlbumentationsTransformsFactory()).get(conf[_type]["transforms"])
        return super().get(conf)
Discard
@@ -5,7 +5,9 @@
 #   1. Move to the project root (where you will find the ReadMe and src folder)
 #   1. Move to the project root (where you will find the ReadMe and src folder)
 #   2. Run the command:
 #   2. Run the command:
 #       python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet +experiment_name=cifar10
 #       python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet +experiment_name=cifar10
-
+#
+#   To use equivalent Albumentations transforms pipeline set dataset_params to cifar10_albumentations_dataset_params:
+#     python src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet dataset_params=cifar10_albumentations_dataset_params
 defaults:
 defaults:
   - training_hyperparams: cifar10_resnet_train_params
   - training_hyperparams: cifar10_resnet_train_params
   - dataset_params: cifar10_dataset_params
   - dataset_params: cifar10_dataset_params
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
  1. # Equivalent to cifar10_dataset_params.yaml, but uses albumentations transforms.
  2. # The purpose of the below configuration is to demonstrate the use of Albumentation transforms in train_from_recipe.
  3. batch_size: 256 # batch size for trainset
  4. val_batch_size: 512 # batch size for valset in DatasetInterface
  5. # TODO: REMOVE ABOVE, HERE FOR COMPATIBILITY UNTIL WE REMOVE DATASET_INTERFACE
  6. train_dataset_params:
  7. root: ./data/cifar10
  8. train: True
  9. transforms:
  10. Albumentations:
  11. Compose:
  12. transforms:
  13. - RandomCrop:
  14. height: 32
  15. width: 32
  16. - HorizontalFlip:
  17. p: 0.5
  18. - Normalize:
  19. mean:
  20. - 0.4914
  21. - 0.4822
  22. - 0.4465
  23. std:
  24. - 0.2023
  25. - 0.1994
  26. - 0.2010
  27. - ToTensorV2
  28. target_transform: null
  29. download: True
  30. train_dataloader_params:
  31. batch_size: 256
  32. num_workers: 8
  33. drop_last: False
  34. pin_memory: True
  35. val_dataset_params:
  36. root: ./data/cifar10
  37. train: False
  38. transforms:
  39. Albumentations:
  40. Compose:
  41. transforms:
  42. - Normalize:
  43. mean:
  44. - 0.4914
  45. - 0.4822
  46. - 0.4465
  47. std:
  48. - 0.2023
  49. - 0.1994
  50. - 0.2010
  51. - ToTensorV2
  52. target_transform: null
  53. download: True
  54. val_dataloader_params:
  55. batch_size: 512
  56. num_workers: 8
  57. drop_last: False
  58. pin_memory: True
Discard
@@ -1,8 +1,7 @@
-from typing import Optional, Callable
+from typing import Optional, Callable, Union
 
 
 from torchvision.transforms import Compose
 from torchvision.transforms import Compose
 
 
-from super_gradients.common.factories.list_factory import ListFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from torchvision.datasets import CIFAR10, CIFAR100
 from torchvision.datasets import CIFAR10, CIFAR100
@@ -18,31 +17,37 @@ class Cifar10(CIFAR10):
     :param target_transform:        Transform to apply to target output
     :param target_transform:        Transform to apply to target output
     :param download:                Download (True) the dataset from source
     :param download:                Download (True) the dataset from source
     """
     """
-    @resolve_param("transforms", ListFactory(TransformsFactory()))
+
+    @resolve_param("transforms", TransformsFactory())
     def __init__(
     def __init__(
         self,
         self,
         root: str,
         root: str,
         train: bool = True,
         train: bool = True,
-        transforms: Optional[Callable] = None,
+        transforms: Union[list, dict] = None,
         target_transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
         download: bool = False,
     ) -> None:
     ) -> None:
+        # TO KEEP BACKWARD COMPATABILITY, WILL BE REMOVED IN THE FUTURE ONCE WE ALLIGN TORCHVISION/NATIVE TRANSFORMS
+        # TREATMENT IN FACTORIES (I.E STATING COMPOSE IN CONFIGS)
+        if isinstance(transforms, list):
+            transforms = Compose(transforms)
+
         super(Cifar10, self).__init__(
         super(Cifar10, self).__init__(
             root=root,
             root=root,
             train=train,
             train=train,
-            transform=Compose(transforms),
+            transform=transforms,
             target_transform=target_transform,
             target_transform=target_transform,
             download=download,
             download=download,
         )
         )
 
 
 
 
 class Cifar100(CIFAR100):
 class Cifar100(CIFAR100):
-    @resolve_param("transforms", ListFactory(TransformsFactory()))
+    @resolve_param("transforms", TransformsFactory())
     def __init__(
     def __init__(
         self,
         self,
         root: str,
         root: str,
         train: bool = True,
         train: bool = True,
-        transforms: Optional[Callable] = None,
+        transforms: Union[list, dict] = None,
         target_transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
         download: bool = False,
     ) -> None:
     ) -> None:
@@ -55,10 +60,15 @@ class Cifar100(CIFAR100):
         :param target_transform:        Transform to apply to target output
         :param target_transform:        Transform to apply to target output
         :param download:                Download (True) the dataset from source
         :param download:                Download (True) the dataset from source
         """
         """
+        # TO KEEP BACKWARD COMPATIBILITY, WILL BE REMOVED IN THE FUTURE ONCE WE ALIGN TORCHVISION/NATIVE TRANSFORMS
+        # TREATMENT IN FACTORIES (I.E. STATING COMPOSE IN CONFIGS)
+        if isinstance(transforms, list):
+            transforms = Compose(transforms)
+
         super(Cifar100, self).__init__(
         super(Cifar100, self).__init__(
             root=root,
             root=root,
             train=train,
             train=train,
-            transform=Compose(transforms),
+            transform=transforms,
             target_transform=target_transform,
             target_transform=target_transform,
             download=download,
             download=download,
         )
         )
Discard
@@ -1,14 +1,19 @@
+from typing import Union
+
 import torchvision.datasets as torch_datasets
 import torchvision.datasets as torch_datasets
 from torchvision.transforms import Compose
 from torchvision.transforms import Compose
 
 
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.common.factories.transforms_factory import TransformsFactory
-from super_gradients.common.factories.list_factory import ListFactory
 
 
 
 
 class ImageNetDataset(torch_datasets.ImageFolder):
 class ImageNetDataset(torch_datasets.ImageFolder):
     """ImageNetDataset dataset"""
     """ImageNetDataset dataset"""
 
 
-    @resolve_param('transforms', factory=ListFactory(TransformsFactory()))
-    def __init__(self, root: str, transforms: list = [], *args, **kwargs):
-        super(ImageNetDataset, self).__init__(root, transform=Compose(transforms), *args, **kwargs)
+    @resolve_param("transforms", factory=TransformsFactory())
+    def __init__(self, root: str, transforms: Union[list, dict] = [], *args, **kwargs):
+        # TO KEEP BACKWARD COMPATABILITY, WILL BE REMOVED IN THE FUTURE ONCE WE ALLIGN TORCHVISION/NATIVE TRANSFORMS
+        # TREATMENT IN FACTORIES (I.E STATING COMPOSE IN CONFIGS)
+        if isinstance(transforms, list):
+            transforms = Compose(transforms)
+        super(ImageNetDataset, self).__init__(root, transform=transforms, *args, **kwargs)
Discard
@@ -1,10 +1,31 @@
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 # PACKAGE IMPORTS FOR EXTERNAL USAGE
 import cv2
 import cv2
-from super_gradients.training.transforms.transforms import DetectionMosaic, DetectionRandomAffine, DetectionHSV,\
-    DetectionPaddedRescale, DetectionTargetsFormatTransform
-from super_gradients.training.transforms.all_transforms import TRANSFORMS, Transforms
+from super_gradients.training.transforms.transforms import (
+    DetectionMosaic,
+    DetectionRandomAffine,
+    DetectionHSV,
+    DetectionPaddedRescale,
+    DetectionTargetsFormatTransform,
+)
+from super_gradients.training.transforms.all_transforms import (
+    TRANSFORMS,
+    ALBUMENTATIONS_TRANSFORMS,
+    Transforms,
+    ALBUMENTATIONS_COMP_TRANSFORMS,
+    imported_albumentations_failure,
+)
 
 
-__all__ = ['TRANSFORMS', 'Transforms', 'DetectionMosaic', 'DetectionRandomAffine', 'DetectionHSV', 'DetectionPaddedRescale',
-           'DetectionTargetsFormatTransform']
+__all__ = [
+    "TRANSFORMS",
+    "ALBUMENTATIONS_TRANSFORMS",
+    "ALBUMENTATIONS_COMP_TRANSFORMS",
+    "Transforms",
+    "DetectionMosaic",
+    "DetectionRandomAffine",
+    "DetectionHSV",
+    "DetectionPaddedRescale",
+    "DetectionTargetsFormatTransform",
+    "imported_albumentations_failure",
+]
 
 
 cv2.setNumThreads(0)
 cv2.setNumThreads(0)
Discard
@@ -1,6 +1,12 @@
 from super_gradients.common.object_names import Transforms
 from super_gradients.common.object_names import Transforms
 from super_gradients.training.datasets.data_augmentation import Lighting, RandomErase
 from super_gradients.training.datasets.data_augmentation import Lighting, RandomErase
 from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation, rand_augment_transform
 from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation, rand_augment_transform
+import importlib
+import inspect
+
+from super_gradients.common.abstractions.abstract_logger import get_logger
+
+
 from super_gradients.training.transforms.transforms import (
 from super_gradients.training.transforms.transforms import (
     SegRandomFlip,
     SegRandomFlip,
     SegRescale,
     SegRescale,
@@ -118,3 +124,31 @@ TRANSFORMS = {
     Transforms.RandomAutocontrast: RandomAutocontrast,
     Transforms.RandomAutocontrast: RandomAutocontrast,
     Transforms.RandomEqualize: RandomEqualize,
     Transforms.RandomEqualize: RandomEqualize,
 }
 }
logger = get_logger(__name__)

try:
    from albumentations import BasicTransform, BaseCompose

    imported_albumentations_failure = None
except (ImportError, NameError, ModuleNotFoundError) as import_err:
    # Albumentations is an optional dependency: remember the failure so that the
    # factories can re-raise it only when albumentations transforms are requested.
    # (Fixed copy-paste: the message previously referred to pytorch_quantization.)
    logger.debug("Failed to import albumentations")
    imported_albumentations_failure = import_err

if imported_albumentations_failure is None:
    # Register every class in the albumentations / albumentations.pytorch namespaces
    # deriving from BasicTransform, keyed by class name.
    ALBUMENTATIONS_TRANSFORMS = {
        name: cls for name, cls in inspect.getmembers(importlib.import_module("albumentations"), inspect.isclass) if issubclass(cls, BasicTransform)
    }
    ALBUMENTATIONS_TRANSFORMS.update(
        {name: cls for name, cls in inspect.getmembers(importlib.import_module("albumentations.pytorch"), inspect.isclass) if issubclass(cls, BasicTransform)}
    )

    # Composition operators (Compose, OneOf, ...) subclass BaseCompose; they are kept
    # in a separate registry so factories can recurse into their nested transforms,
    # and also merged into the main registry for direct lookup.
    ALBUMENTATIONS_COMP_TRANSFORMS = {
        name: cls
        for name, cls in inspect.getmembers(importlib.import_module("albumentations.core.composition"), inspect.isclass)
        if issubclass(cls, BaseCompose)
    }
    ALBUMENTATIONS_TRANSFORMS.update(ALBUMENTATIONS_COMP_TRANSFORMS)
else:
    # Albumentations unavailable: registries are None; factories raise the recorded
    # failure on first use.
    ALBUMENTATIONS_TRANSFORMS = None
    ALBUMENTATIONS_COMP_TRANSFORMS = None
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
  1. from typing import Callable
  2. from abc import abstractmethod, ABC
  3. import numpy as np
  4. class TransformsPipelineAdaptorBase(ABC):
  5. def __init__(self, composed_transforms: Callable):
  6. self.composed_transforms = composed_transforms
  7. @abstractmethod
  8. def __call__(self, sample, *args, **kwargs):
  9. raise NotImplementedError
  10. @abstractmethod
  11. def prep_for_transforms(self, sample):
  12. raise NotImplementedError
  13. @abstractmethod
  14. def post_transforms_processing(self, sample):
  15. raise NotImplementedError
  16. class AlbumentationsAdaptor(TransformsPipelineAdaptorBase):
  17. def __init__(self, composed_transforms: Callable):
  18. super(AlbumentationsAdaptor, self).__init__(composed_transforms)
  19. def __call__(self, sample, *args, **kwargs):
  20. sample = self.prep_for_transforms(sample)
  21. sample = self.composed_transforms(**sample)["image"]
  22. sample = self.post_transforms_processing(sample)
  23. return sample
  24. def prep_for_transforms(self, sample):
  25. return {"image": np.array(sample)}
  26. def post_transforms_processing(self, sample):
  27. return sample
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  1. import unittest
  2. import numpy as np
  3. from super_gradients.training.datasets import Cifar10, Cifar100, ImageNetDataset
  4. from albumentations import Compose, HorizontalFlip, InvertImg
  5. class AlbumentationsIntegrationTest(unittest.TestCase):
  6. def _apply_aug(self, img_no_aug):
  7. pipe = Compose(transforms=[HorizontalFlip(p=1.0), InvertImg(p=1.0)])
  8. img_no_aug_transformed = pipe(image=np.array(img_no_aug))["image"]
  9. return img_no_aug_transformed
  10. def test_cifar10_albumentations_integration(self):
  11. ds_no_aug = Cifar10(root="./data/cifar10", train=True, download=True)
  12. img_no_aug, _ = ds_no_aug.__getitem__(0)
  13. ds = Cifar10(
  14. root="./data/cifar10",
  15. train=True,
  16. download=True,
  17. transforms={"Albumentations": {"Compose": {"transforms": [{"HorizontalFlip": {"p": 1.0}}, {"InvertImg": {"p": 1.0}}]}}},
  18. )
  19. img_aug, _ = ds.__getitem__(0)
  20. img_no_aug_transformed = self._apply_aug(img_no_aug)
  21. self.assertTrue(np.allclose(img_no_aug_transformed, img_aug))
  22. def test_cifar100_albumentations_integration(self):
  23. ds_no_aug = Cifar100(root="./data/cifar100", train=True, download=True)
  24. img_no_aug, _ = ds_no_aug.__getitem__(0)
  25. ds = Cifar100(
  26. root="./data/cifar100",
  27. train=True,
  28. download=True,
  29. transforms={"Albumentations": {"Compose": {"transforms": [{"HorizontalFlip": {"p": 1}}, {"InvertImg": {"p": 1.0}}]}}},
  30. )
  31. img_aug, _ = ds.__getitem__(0)
  32. img_no_aug_transformed = self._apply_aug(img_no_aug)
  33. self.assertTrue(np.allclose(img_no_aug_transformed, img_aug))
  34. def test_imagenet_albumentations_integration(self):
  35. ds_no_aug = ImageNetDataset(root="/data/Imagenet/val")
  36. img_no_aug, _ = ds_no_aug.__getitem__(0)
  37. ds = ImageNetDataset(
  38. root="/data/Imagenet/val", transforms={"Albumentations": {"Compose": {"transforms": [{"HorizontalFlip": {"p": 1}}, {"InvertImg": {"p": 1.0}}]}}}
  39. )
  40. img_aug, _ = ds.__getitem__(0)
  41. img_no_aug_transformed = self._apply_aug(img_no_aug)
  42. self.assertTrue(np.allclose(img_no_aug_transformed, img_aug))
  43. if __name__ == "__main__":
  44. unittest.main()
Discard