Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dtd.py 4.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
  1. import os
  2. import pathlib
  3. from typing import Any, Callable, Optional, Tuple, Union
  4. from .folder import default_loader
  5. from .utils import download_and_extract_archive, verify_str_arg
  6. from .vision import VisionDataset
  7. class DTD(VisionDataset):
  8. """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
  9. Args:
  10. root (str or ``pathlib.Path``): Root directory of the dataset.
  11. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
  12. partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
  13. .. note::
  14. The partition only changes which split each image belongs to. Thus, regardless of the selected
  15. partition, combining all splits will result in all images.
  16. transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
  17. and returns a transformed version. E.g, ``transforms.RandomCrop``
  18. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
  19. download (bool, optional): If True, downloads the dataset from the internet and
  20. puts it in root directory. If dataset is already downloaded, it is not
  21. downloaded again. Default is False.
  22. loader (callable, optional): A function to load an image given its path.
  23. By default, it uses PIL as its image loader, but users could also pass in
  24. ``torchvision.io.decode_image`` for decoding image data into tensors directly.
  25. """
  26. _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
  27. _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"
  28. def __init__(
  29. self,
  30. root: Union[str, pathlib.Path],
  31. split: str = "train",
  32. partition: int = 1,
  33. transform: Optional[Callable] = None,
  34. target_transform: Optional[Callable] = None,
  35. download: bool = False,
  36. loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
  37. ) -> None:
  38. self._split = verify_str_arg(split, "split", ("train", "val", "test"))
  39. if not isinstance(partition, int) and not (1 <= partition <= 10):
  40. raise ValueError(
  41. f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
  42. f"but got {partition} instead"
  43. )
  44. self._partition = partition
  45. super().__init__(root, transform=transform, target_transform=target_transform)
  46. self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
  47. self._data_folder = self._base_folder / "dtd"
  48. self._meta_folder = self._data_folder / "labels"
  49. self._images_folder = self._data_folder / "images"
  50. if download:
  51. self._download()
  52. if not self._check_exists():
  53. raise RuntimeError("Dataset not found. You can use download=True to download it")
  54. self._image_files = []
  55. classes = []
  56. with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
  57. for line in file:
  58. cls, name = line.strip().split("/")
  59. self._image_files.append(self._images_folder.joinpath(cls, name))
  60. classes.append(cls)
  61. self.classes = sorted(set(classes))
  62. self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
  63. self._labels = [self.class_to_idx[cls] for cls in classes]
  64. self.loader = loader
  65. def __len__(self) -> int:
  66. return len(self._image_files)
  67. def __getitem__(self, idx: int) -> Tuple[Any, Any]:
  68. image_file, label = self._image_files[idx], self._labels[idx]
  69. image = self.loader(image_file)
  70. if self.transform:
  71. image = self.transform(image)
  72. if self.target_transform:
  73. label = self.target_transform(label)
  74. return image, label
  75. def extra_repr(self) -> str:
  76. return f"split={self._split}, partition={self._partition}"
  77. def _check_exists(self) -> bool:
  78. return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)
  79. def _download(self) -> None:
  80. if self._check_exists():
  81. return
  82. download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...