Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

sun397.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
  1. from pathlib import Path
  2. from typing import Any, Callable, Optional, Tuple, Union
  3. from .folder import default_loader
  4. from .utils import download_and_extract_archive
  5. from .vision import VisionDataset
  class SUN397(VisionDataset):
      """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.

      The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
      397 categories with 108'754 images.

      Args:
          root (str or ``pathlib.Path``): Root directory of the dataset.
          transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor,
              depending on the given loader, and returns a transformed version. E.g., ``transforms.RandomCrop``
          target_transform (callable, optional): A function/transform that takes in the target and transforms it.
          download (bool, optional): If true, downloads the dataset from the internet and
              puts it in root directory. If dataset is already downloaded, it is not
              downloaded again.
          loader (callable, optional): A function to load an image given its path.
              By default, it uses PIL as its image loader, but users could also pass in
              ``torchvision.io.decode_image`` for decoding image data into tensors directly.
      """

      _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
      _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"

      def __init__(
          self,
          root: Union[str, Path],
          transform: Optional[Callable] = None,
          target_transform: Optional[Callable] = None,
          download: bool = False,
          loader: Callable[[Union[str, Path]], Any] = default_loader,
      ) -> None:
          super().__init__(root, transform=transform, target_transform=target_transform)
          self._data_dir = Path(self.root) / "SUN397"

          if download:
              self._download()

          if not self._check_exists():
              raise RuntimeError("Dataset not found. You can use download=True to download it")

          # Each line's first 3 characters are dropped; presumably this is a "/<letter>/"
          # bucket prefix (e.g. "/a/") — TODO confirm against ClassName.txt. What remains
          # must match the directory path built for the labels below.
          # Blank lines are skipped so they cannot turn into a bogus "" class.
          with open(self._data_dir / "ClassName.txt", encoding="utf-8") as f:
              self.classes = [line[3:].strip() for line in f if line.strip()]
          self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

          # rglob yields files in a filesystem-dependent order; sorting makes a given
          # index refer to the same sample on every machine and every run.
          self._image_files = sorted(self._data_dir.rglob("sun_*.jpg"))

          # The label key is the path between the first directory under SUN397/ and the
          # file name, joined with "/" — the same form as the entries of self.classes.
          self._labels = [
              self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])]
              for path in self._image_files
          ]

          self.loader = loader

      def __len__(self) -> int:
          """Return the number of images found under the dataset directory."""
          return len(self._image_files)

      def __getitem__(self, idx: int) -> Tuple[Any, Any]:
          """Return the ``(image, label)`` pair at position ``idx``.

          The image is produced by ``self.loader`` and then passed through ``transform``
          when one was given. The label is an int index into ``self.classes``, passed
          through ``target_transform`` when one was given.
          """
          image_file, label = self._image_files[idx], self._labels[idx]
          image = self.loader(image_file)

          # Compare against None (the documented "not provided" value) rather than
          # relying on truthiness, which a callable object could override.
          if self.transform is not None:
              image = self.transform(image)

          if self.target_transform is not None:
              label = self.target_transform(label)

          return image, label

      def _check_exists(self) -> bool:
          """Return True if the extracted ``SUN397`` directory exists under ``root``."""
          return self._data_dir.is_dir()

      def _download(self) -> None:
          """Download and extract the archive into ``root`` unless already extracted.

          The expected MD5 is passed to ``download_and_extract_archive``, presumably for
          integrity verification (behavior of that helper is not visible here).
          """
          if self._check_exists():
              return
          download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...