Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#381 Feature/sg 000 connect to lab

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/sg-000_connect_to_lab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
  1. from typing import Optional, Callable
  2. from torchvision.transforms import Compose
  3. from super_gradients.common.factories.list_factory import ListFactory
  4. from super_gradients.common.factories.transforms_factory import TransformsFactory
  5. from super_gradients.common.decorators.factory_decorator import resolve_param
  6. from torchvision.datasets import CIFAR10, CIFAR100
  7. class Cifar10(CIFAR10):
  8. """
  9. CIFAR10 Dataset
  10. :param root: Path for the data to be extracted
  11. :param train: Bool to load training (True) or validation (False) part of the dataset
  12. :param transforms: List of transforms to apply sequentially on sample. Wrapped internally with torchvision.Compose
  13. :param target_transform: Transform to apply to target output
  14. :param download: Download (True) the dataset from source
  15. """
  16. @resolve_param("transforms", ListFactory(TransformsFactory()))
  17. def __init__(
  18. self,
  19. root: str,
  20. train: bool = True,
  21. transforms: Optional[Callable] = None,
  22. target_transform: Optional[Callable] = None,
  23. download: bool = False,
  24. ) -> None:
  25. super(Cifar10, self).__init__(
  26. root=root,
  27. train=train,
  28. transform=Compose(transforms),
  29. target_transform=target_transform,
  30. download=download,
  31. )
  32. class Cifar100(CIFAR100):
  33. @resolve_param("transforms", ListFactory(TransformsFactory()))
  34. def __init__(
  35. self,
  36. root: str,
  37. train: bool = True,
  38. transforms: Optional[Callable] = None,
  39. target_transform: Optional[Callable] = None,
  40. download: bool = False,
  41. ) -> None:
  42. """
  43. CIFAR100 Dataset
  44. :param root: Path for the data to be extracted
  45. :param train: Bool to load training (True) or validation (False) part of the dataset
  46. :param transforms: List of transforms to apply sequentially on sample. Wrapped internally with torchvision.Compose
  47. :param target_transform: Transform to apply to target output
  48. :param download: Download (True) the dataset from source
  49. """
  50. super(Cifar100, self).__init__(
  51. root=root,
  52. train=train,
  53. transform=Compose(transforms),
  54. target_transform=target_transform,
  55. download=download,
  56. )
Discard
Tip!

Press p or to see the previous file or, n or to see the next file