Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

semeion.py 3.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
  1. import os.path
  2. from pathlib import Path
  3. from typing import Any, Callable, Optional, Tuple, Union
  4. import numpy as np
  5. from PIL import Image
  6. from .utils import check_integrity, download_url
  7. from .vision import VisionDataset
  8. class SEMEION(VisionDataset):
  9. r"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.
  10. Args:
  11. root (str or ``pathlib.Path``): Root directory of dataset where directory
  12. ``semeion.py`` exists.
  13. transform (callable, optional): A function/transform that takes in a PIL image
  14. and returns a transformed version. E.g, ``transforms.RandomCrop``
  15. target_transform (callable, optional): A function/transform that takes in the
  16. target and transforms it.
  17. download (bool, optional): If true, downloads the dataset from the internet and
  18. puts it in root directory. If dataset is already downloaded, it is not
  19. downloaded again.
  20. """
  21. url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
  22. filename = "semeion.data"
  23. md5_checksum = "cb545d371d2ce14ec121470795a77432"
  24. def __init__(
  25. self,
  26. root: Union[str, Path],
  27. transform: Optional[Callable] = None,
  28. target_transform: Optional[Callable] = None,
  29. download: bool = True,
  30. ) -> None:
  31. super().__init__(root, transform=transform, target_transform=target_transform)
  32. if download:
  33. self.download()
  34. if not self._check_integrity():
  35. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  36. fp = os.path.join(self.root, self.filename)
  37. data = np.loadtxt(fp)
  38. # convert value to 8 bit unsigned integer
  39. # color (white #255) the pixels
  40. self.data = (data[:, :256] * 255).astype("uint8")
  41. self.data = np.reshape(self.data, (-1, 16, 16))
  42. self.labels = np.nonzero(data[:, 256:])[1]
  43. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  44. """
  45. Args:
  46. index (int): Index
  47. Returns:
  48. tuple: (image, target) where target is index of the target class.
  49. """
  50. img, target = self.data[index], int(self.labels[index])
  51. # doing this so that it is consistent with all other datasets
  52. # to return a PIL Image
  53. img = Image.fromarray(img, mode="L")
  54. if self.transform is not None:
  55. img = self.transform(img)
  56. if self.target_transform is not None:
  57. target = self.target_transform(target)
  58. return img, target
  59. def __len__(self) -> int:
  60. return len(self.data)
  61. def _check_integrity(self) -> bool:
  62. root = self.root
  63. fpath = os.path.join(root, self.filename)
  64. if not check_integrity(fpath, self.md5_checksum):
  65. return False
  66. return True
  67. def download(self) -> None:
  68. if self._check_integrity():
  69. return
  70. root = self.root
  71. download_url(self.url, root, self.filename, self.md5_checksum)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...