Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

stl10.py 7.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
  1. import os.path
  2. from pathlib import Path
  3. from typing import Any, Callable, cast, Optional, Tuple, Union
  4. import numpy as np
  5. from PIL import Image
  6. from .utils import check_integrity, download_and_extract_archive, verify_str_arg
  7. from .vision import VisionDataset
  8. class STL10(VisionDataset):
  9. """`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
  10. Args:
  11. root (str or ``pathlib.Path``): Root directory of dataset where directory
  12. ``stl10_binary`` exists.
  13. split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
  14. Accordingly, dataset is selected.
  15. folds (int, optional): One of {0-9} or None.
  16. For training, loads one of the 10 pre-defined folds of 1k samples for the
  17. standard evaluation procedure. If no value is passed, loads the 5k samples.
  18. transform (callable, optional): A function/transform that takes in a PIL image
  19. and returns a transformed version. E.g, ``transforms.RandomCrop``
  20. target_transform (callable, optional): A function/transform that takes in the
  21. target and transforms it.
  22. download (bool, optional): If true, downloads the dataset from the internet and
  23. puts it in root directory. If dataset is already downloaded, it is not
  24. downloaded again.
  25. """
  26. base_folder = "stl10_binary"
  27. url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
  28. filename = "stl10_binary.tar.gz"
  29. tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb"
  30. class_names_file = "class_names.txt"
  31. folds_list_file = "fold_indices.txt"
  32. train_list = [
  33. ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"],
  34. ["train_y.bin", "5a34089d4802c674881badbb80307741"],
  35. ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"],
  36. ]
  37. test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]]
  38. splits = ("train", "train+unlabeled", "unlabeled", "test")
  39. def __init__(
  40. self,
  41. root: Union[str, Path],
  42. split: str = "train",
  43. folds: Optional[int] = None,
  44. transform: Optional[Callable] = None,
  45. target_transform: Optional[Callable] = None,
  46. download: bool = False,
  47. ) -> None:
  48. super().__init__(root, transform=transform, target_transform=target_transform)
  49. self.split = verify_str_arg(split, "split", self.splits)
  50. self.folds = self._verify_folds(folds)
  51. if download:
  52. self.download()
  53. elif not self._check_integrity():
  54. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  55. # now load the picked numpy arrays
  56. self.labels: Optional[np.ndarray]
  57. if self.split == "train":
  58. self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
  59. self.labels = cast(np.ndarray, self.labels)
  60. self.__load_folds(folds)
  61. elif self.split == "train+unlabeled":
  62. self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
  63. self.labels = cast(np.ndarray, self.labels)
  64. self.__load_folds(folds)
  65. unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
  66. self.data = np.concatenate((self.data, unlabeled_data))
  67. self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
  68. elif self.split == "unlabeled":
  69. self.data, _ = self.__loadfile(self.train_list[2][0])
  70. self.labels = np.asarray([-1] * self.data.shape[0])
  71. else: # self.split == 'test':
  72. self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0])
  73. class_file = os.path.join(self.root, self.base_folder, self.class_names_file)
  74. if os.path.isfile(class_file):
  75. with open(class_file) as f:
  76. self.classes = f.read().splitlines()
  77. def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
  78. if folds is None:
  79. return folds
  80. elif isinstance(folds, int):
  81. if folds in range(10):
  82. return folds
  83. msg = "Value for argument folds should be in the range [0, 10), but got {}."
  84. raise ValueError(msg.format(folds))
  85. else:
  86. msg = "Expected type None or int for argument folds, but got type {}."
  87. raise ValueError(msg.format(type(folds)))
  88. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  89. """
  90. Args:
  91. index (int): Index
  92. Returns:
  93. tuple: (image, target) where target is index of the target class.
  94. """
  95. target: Optional[int]
  96. if self.labels is not None:
  97. img, target = self.data[index], int(self.labels[index])
  98. else:
  99. img, target = self.data[index], None
  100. # doing this so that it is consistent with all other datasets
  101. # to return a PIL Image
  102. img = Image.fromarray(np.transpose(img, (1, 2, 0)))
  103. if self.transform is not None:
  104. img = self.transform(img)
  105. if self.target_transform is not None:
  106. target = self.target_transform(target)
  107. return img, target
  108. def __len__(self) -> int:
  109. return self.data.shape[0]
  110. def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
  111. labels = None
  112. if labels_file:
  113. path_to_labels = os.path.join(self.root, self.base_folder, labels_file)
  114. with open(path_to_labels, "rb") as f:
  115. labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based
  116. path_to_data = os.path.join(self.root, self.base_folder, data_file)
  117. with open(path_to_data, "rb") as f:
  118. # read whole file in uint8 chunks
  119. everything = np.fromfile(f, dtype=np.uint8)
  120. images = np.reshape(everything, (-1, 3, 96, 96))
  121. images = np.transpose(images, (0, 1, 3, 2))
  122. return images, labels
  123. def _check_integrity(self) -> bool:
  124. for filename, md5 in self.train_list + self.test_list:
  125. fpath = os.path.join(self.root, self.base_folder, filename)
  126. if not check_integrity(fpath, md5):
  127. return False
  128. return True
  129. def download(self) -> None:
  130. if self._check_integrity():
  131. return
  132. download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
  133. self._check_integrity()
  134. def extra_repr(self) -> str:
  135. return "Split: {split}".format(**self.__dict__)
  136. def __load_folds(self, folds: Optional[int]) -> None:
  137. # loads one of the folds if specified
  138. if folds is None:
  139. return
  140. path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
  141. with open(path_to_folds) as f:
  142. str_idx = f.read().splitlines()[folds]
  143. list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
  144. self.data = self.data[list_idx, :, :, :]
  145. if self.labels is not None:
  146. self.labels = self.labels[list_idx]
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...