1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
|
- from os.path import join
- from pathlib import Path
- from typing import Any, Callable, List, Optional, Tuple, Union
- from PIL import Image
- from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
- from .vision import VisionDataset
class Omniglot(VisionDataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts it in root directory. If the zip files are already downloaded, they are not
            downloaded again.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.

    Raises:
        RuntimeError: If ``check_integrity`` rejects the zip archive of the selected
            split (checked after the optional download step).
    """

    # Sub-directory of the user-supplied ``root`` that holds the archives and images.
    folder = "omniglot-py"
    download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
    # Expected MD5 of each split's zip archive, keyed by the archive name without ".zip".
    zips_md5 = {
        "images_background": "68d2efa1b9178cc56df9314c21c6e718",
        "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
    }

    def __init__(
        self,
        root: Union[str, Path],
        background: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        loader: Optional[Callable[[Union[str, Path]], Any]] = None,
    ) -> None:
        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
        # Must be set before download()/_check_integrity(), which both pick the split from it.
        self.background = background

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.target_folder = join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)

        # Relative "<alphabet>/<character>" paths; a character's position in this
        # list is its class index. Built with a flat comprehension rather than
        # ``sum(lists, [])``, which re-copies the accumulator on every step.
        self._characters: List[str] = [
            join(alphabet, character)
            for alphabet in self._alphabets
            for character in list_dir(join(self.target_folder, alphabet))
        ]

        # One inner list per character: (image file name, class index) pairs.
        self._character_images = [
            [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
            for idx, character in enumerate(self._characters)
        ]

        # Flattened view that __len__ and __getitem__ index into.
        self._flat_character_images: List[Tuple[str, int]] = [
            pair for images in self._character_images for pair in images
        ]
        self.loader = loader

    def __len__(self) -> int:
        """Return the total number of images across all characters."""
        return len(self._flat_character_images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)

        if self.loader is None:
            # Default loader: open with PIL and convert to mode "L".
            image = Image.open(image_path, mode="r").convert("L")
        else:
            image = self.loader(image_path)

        # Compare against None (as done for ``loader``) so a callable that happens
        # to evaluate falsy, e.g. one defining ``__len__``, is still applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self) -> bool:
        """Return whether ``check_integrity`` accepts the selected split's zip archive."""
        zip_filename = self._get_target_folder()
        return check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename])

    def download(self) -> None:
        """Download and extract the selected split's archive unless it already passes the integrity check."""
        if self._check_integrity():
            return

        filename = self._get_target_folder()
        zip_filename = filename + ".zip"
        url = f"{self.download_url_prefix}/{zip_filename}"
        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])

    def _get_target_folder(self) -> str:
        """Return the split's folder name, which is also its archive name without ``.zip``."""
        return "images_background" if self.background else "images_evaluation"
|