Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

vision.py 4.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
  1. import os
  2. from pathlib import Path
  3. from typing import Any, Callable, List, Optional, Tuple, Union
  4. import torch.utils.data as data
  5. from ..utils import _log_api_usage_once
  6. class VisionDataset(data.Dataset):
  7. """
  8. Base Class For making datasets which are compatible with torchvision.
  9. It is necessary to override the ``__getitem__`` and ``__len__`` method.
  10. Args:
  11. root (string, optional): Root directory of dataset. Only used for `__repr__`.
  12. transforms (callable, optional): A function/transforms that takes in
  13. an image and a label and returns the transformed versions of both.
  14. transform (callable, optional): A function/transform that takes in a PIL image
  15. and returns a transformed version. E.g, ``transforms.RandomCrop``
  16. target_transform (callable, optional): A function/transform that takes in the
  17. target and transforms it.
  18. .. note::
  19. :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
  20. """
  21. _repr_indent = 4
  22. def __init__(
  23. self,
  24. root: Union[str, Path] = None, # type: ignore[assignment]
  25. transforms: Optional[Callable] = None,
  26. transform: Optional[Callable] = None,
  27. target_transform: Optional[Callable] = None,
  28. ) -> None:
  29. _log_api_usage_once(self)
  30. if isinstance(root, str):
  31. root = os.path.expanduser(root)
  32. self.root = root
  33. has_transforms = transforms is not None
  34. has_separate_transform = transform is not None or target_transform is not None
  35. if has_transforms and has_separate_transform:
  36. raise ValueError("Only transforms or transform/target_transform can be passed as argument")
  37. # for backwards-compatibility
  38. self.transform = transform
  39. self.target_transform = target_transform
  40. if has_separate_transform:
  41. transforms = StandardTransform(transform, target_transform)
  42. self.transforms = transforms
  43. def __getitem__(self, index: int) -> Any:
  44. """
  45. Args:
  46. index (int): Index
  47. Returns:
  48. (Any): Sample and meta data, optionally transformed by the respective transforms.
  49. """
  50. raise NotImplementedError
  51. def __len__(self) -> int:
  52. raise NotImplementedError
  53. def __repr__(self) -> str:
  54. head = "Dataset " + self.__class__.__name__
  55. body = [f"Number of datapoints: {self.__len__()}"]
  56. if self.root is not None:
  57. body.append(f"Root location: {self.root}")
  58. body += self.extra_repr().splitlines()
  59. if hasattr(self, "transforms") and self.transforms is not None:
  60. body += [repr(self.transforms)]
  61. lines = [head] + [" " * self._repr_indent + line for line in body]
  62. return "\n".join(lines)
  63. def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
  64. lines = transform.__repr__().splitlines()
  65. return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
  66. def extra_repr(self) -> str:
  67. return ""
  68. class StandardTransform:
  69. def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
  70. self.transform = transform
  71. self.target_transform = target_transform
  72. def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
  73. if self.transform is not None:
  74. input = self.transform(input)
  75. if self.target_transform is not None:
  76. target = self.target_transform(target)
  77. return input, target
  78. def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
  79. lines = transform.__repr__().splitlines()
  80. return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
  81. def __repr__(self) -> str:
  82. body = [self.__class__.__name__]
  83. if self.transform is not None:
  84. body += self._format_transform_repr(self.transform, "Transform: ")
  85. if self.target_transform is not None:
  86. body += self._format_transform_repr(self.target_transform, "Target transform: ")
  87. return "\n".join(body)
Tip!

Press p or the left arrow key to see the previous file, or n or the right arrow key to see the next file

Comments

Loading...