Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

sbu.py 4.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
  1. import os
  2. from pathlib import Path
  3. from typing import Any, Callable, Optional, Tuple, Union
  4. from .folder import default_loader
  5. from .utils import check_integrity, download_and_extract_archive, download_url
  6. from .vision import VisionDataset
  7. class SBU(VisionDataset):
  8. """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
  9. Args:
  10. root (str or ``pathlib.Path``): Root directory of dataset where tarball
  11. ``SBUCaptionedPhotoDataset.tar.gz`` exists.
  12. transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
  13. and returns a transformed version. E.g, ``transforms.RandomCrop``
  14. target_transform (callable, optional): A function/transform that takes in the
  15. target and transforms it.
  16. download (bool, optional): If True, downloads the dataset from the internet and
  17. puts it in root directory. If dataset is already downloaded, it is not
  18. downloaded again.
  19. loader (callable, optional): A function to load an image given its path.
  20. By default, it uses PIL as its image loader, but users could also pass in
  21. ``torchvision.io.decode_image`` for decoding image data into tensors directly.
  22. """
  23. url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
  24. filename = "SBUCaptionedPhotoDataset.tar.gz"
  25. md5_checksum = "9aec147b3488753cf758b4d493422285"
  26. def __init__(
  27. self,
  28. root: Union[str, Path],
  29. transform: Optional[Callable] = None,
  30. target_transform: Optional[Callable] = None,
  31. download: bool = True,
  32. loader: Callable[[str], Any] = default_loader,
  33. ) -> None:
  34. super().__init__(root, transform=transform, target_transform=target_transform)
  35. self.loader = loader
  36. if download:
  37. self.download()
  38. if not self._check_integrity():
  39. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  40. # Read the caption for each photo
  41. self.photos = []
  42. self.captions = []
  43. file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
  44. file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
  45. for line1, line2 in zip(open(file1), open(file2)):
  46. url = line1.rstrip()
  47. photo = os.path.basename(url)
  48. filename = os.path.join(self.root, "dataset", photo)
  49. if os.path.exists(filename):
  50. caption = line2.rstrip()
  51. self.photos.append(photo)
  52. self.captions.append(caption)
  53. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  54. """
  55. Args:
  56. index (int): Index
  57. Returns:
  58. tuple: (image, target) where target is a caption for the photo.
  59. """
  60. filename = os.path.join(self.root, "dataset", self.photos[index])
  61. img = self.loader(filename)
  62. if self.transform is not None:
  63. img = self.transform(img)
  64. target = self.captions[index]
  65. if self.target_transform is not None:
  66. target = self.target_transform(target)
  67. return img, target
  68. def __len__(self) -> int:
  69. """The number of photos in the dataset."""
  70. return len(self.photos)
  71. def _check_integrity(self) -> bool:
  72. """Check the md5 checksum of the downloaded tarball."""
  73. root = self.root
  74. fpath = os.path.join(root, self.filename)
  75. if not check_integrity(fpath, self.md5_checksum):
  76. return False
  77. return True
  78. def download(self) -> None:
  79. """Download and extract the tarball, and download each individual photo."""
  80. if self._check_integrity():
  81. return
  82. download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
  83. # Download individual photos
  84. with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
  85. for line in fh:
  86. url = line.rstrip()
  87. try:
  88. download_url(url, os.path.join(self.root, "dataset"))
  89. except OSError:
  90. # The images point to public images on Flickr.
  91. # Note: Images might be removed by users at anytime.
  92. pass
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...