Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#378 Feature/sg 281 add kd notebook

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-281-add_kd_notebook
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  1. import inspect
  2. from typing import Union, Mapping
  3. from omegaconf import ListConfig
  4. from torchvision import transforms
  5. from super_gradients.common.factories.base_factory import BaseFactory
  6. from super_gradients.common.factories.list_factory import ListFactory
  7. from super_gradients.training.datasets.data_augmentation import Lighting, RandomErase
  8. from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation, rand_augment_transform
  9. from super_gradients.training.transforms.transforms import RandomFlip, Rescale, RandomRescale, RandomRotate, \
  10. CropImageAndMask, RandomGaussianBlur, PadShortToCropSize, ResizeSeg, ColorJitterSeg, DetectionMosaic, DetectionRandomAffine, \
  11. DetectionMixup, DetectionHSV, \
  12. DetectionHorizontalFlip, DetectionTargetsFormat, DetectionPaddedRescale, \
  13. DetectionTargetsFormatTransform
  14. class TransformsFactory(BaseFactory):
  15. def __init__(self):
  16. type_dict = {
  17. 'RandomFlipSeg': RandomFlip,
  18. 'ResizeSeg': ResizeSeg,
  19. 'RescaleSeg': Rescale,
  20. 'RandomRescaleSeg': RandomRescale,
  21. 'RandomRotateSeg': RandomRotate,
  22. 'CropImageAndMaskSeg': CropImageAndMask,
  23. 'RandomGaussianBlurSeg': RandomGaussianBlur,
  24. 'PadShortToCropSizeSeg': PadShortToCropSize,
  25. 'ColorJitterSeg': ColorJitterSeg,
  26. "DetectionMosaic": DetectionMosaic,
  27. "DetectionRandomAffine": DetectionRandomAffine,
  28. "DetectionMixup": DetectionMixup,
  29. "DetectionHSV": DetectionHSV,
  30. "DetectionHorizontalFlip": DetectionHorizontalFlip,
  31. "DetectionPaddedRescale": DetectionPaddedRescale,
  32. "DetectionTargetsFormat": DetectionTargetsFormat,
  33. "DetectionTargetsFormatTransform": DetectionTargetsFormatTransform,
  34. 'RandomResizedCropAndInterpolation': RandomResizedCropAndInterpolation,
  35. 'RandAugmentTransform': rand_augment_transform,
  36. 'Lighting': Lighting,
  37. 'RandomErase': RandomErase
  38. }
  39. for name, obj in inspect.getmembers(transforms, inspect.isclass):
  40. if name in type_dict:
  41. raise RuntimeError(f'key {name} already exists in dictionary')
  42. type_dict[name] = obj
  43. super().__init__(type_dict)
  44. def get(self, conf: Union[str, dict]):
  45. # SPECIAL HANDLING FOR COMPOSE
  46. if isinstance(conf, Mapping) and 'Compose' in conf:
  47. conf['Compose']['transforms'] = ListFactory(TransformsFactory()).get(conf['Compose']['transforms'])
  48. elif isinstance(conf, (list, ListConfig)):
  49. conf = ListFactory(TransformsFactory()).get(conf)
  50. return super().get(conf)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file