Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

omniglot.py 4.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
  1. from os.path import join
  2. from pathlib import Path
  3. from typing import Any, Callable, List, Optional, Tuple, Union
  4. from PIL import Image
  5. from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
  6. from .vision import VisionDataset
  7. class Omniglot(VisionDataset):
  8. """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
  9. Args:
  10. root (str or ``pathlib.Path``): Root directory of dataset where directory
  11. ``omniglot-py`` exists.
  12. background (bool, optional): If True, creates dataset from the "background" set, otherwise
  13. creates from the "evaluation" set. This terminology is defined by the authors.
  14. transform (callable, optional): A function/transform that takes in a PIL image
  15. and returns a transformed version. E.g, ``transforms.RandomCrop``
  16. target_transform (callable, optional): A function/transform that takes in the
  17. target and transforms it.
  18. download (bool, optional): If true, downloads the dataset zip files from the internet and
  19. puts it in root directory. If the zip files are already downloaded, they are not
  20. downloaded again.
  21. loader (callable, optional): A function to load an image given its path.
  22. By default, it uses PIL as its image loader, but users could also pass in
  23. ``torchvision.io.decode_image`` for decoding image data into tensors directly.
  24. """
  25. folder = "omniglot-py"
  26. download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
  27. zips_md5 = {
  28. "images_background": "68d2efa1b9178cc56df9314c21c6e718",
  29. "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
  30. }
  31. def __init__(
  32. self,
  33. root: Union[str, Path],
  34. background: bool = True,
  35. transform: Optional[Callable] = None,
  36. target_transform: Optional[Callable] = None,
  37. download: bool = False,
  38. loader: Optional[Callable[[Union[str, Path]], Any]] = None,
  39. ) -> None:
  40. super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
  41. self.background = background
  42. if download:
  43. self.download()
  44. if not self._check_integrity():
  45. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  46. self.target_folder = join(self.root, self._get_target_folder())
  47. self._alphabets = list_dir(self.target_folder)
  48. self._characters: List[str] = sum(
  49. ([join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets), []
  50. )
  51. self._character_images = [
  52. [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
  53. for idx, character in enumerate(self._characters)
  54. ]
  55. self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
  56. self.loader = loader
  57. def __len__(self) -> int:
  58. return len(self._flat_character_images)
  59. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  60. """
  61. Args:
  62. index (int): Index
  63. Returns:
  64. tuple: (image, target) where target is index of the target character class.
  65. """
  66. image_name, character_class = self._flat_character_images[index]
  67. image_path = join(self.target_folder, self._characters[character_class], image_name)
  68. image = Image.open(image_path, mode="r").convert("L") if self.loader is None else self.loader(image_path)
  69. if self.transform:
  70. image = self.transform(image)
  71. if self.target_transform:
  72. character_class = self.target_transform(character_class)
  73. return image, character_class
  74. def _check_integrity(self) -> bool:
  75. zip_filename = self._get_target_folder()
  76. if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]):
  77. return False
  78. return True
  79. def download(self) -> None:
  80. if self._check_integrity():
  81. return
  82. filename = self._get_target_folder()
  83. zip_filename = filename + ".zip"
  84. url = self.download_url_prefix + "/" + zip_filename
  85. download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
  86. def _get_target_folder(self) -> str:
  87. return "images_background" if self.background else "images_evaluation"
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...