Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

usps.py 3.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
  1. import os
  2. from pathlib import Path
  3. from typing import Any, Callable, Optional, Tuple, Union
  4. import numpy as np
  5. from PIL import Image
  6. from .utils import download_url
  7. from .vision import VisionDataset
  8. class USPS(VisionDataset):
  9. """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.
  10. The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
  11. The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
  12. and make pixel values in ``[0, 255]``.
  13. Args:
  14. root (str or ``pathlib.Path``): Root directory of dataset to store``USPS`` data files.
  15. train (bool, optional): If True, creates dataset from ``usps.bz2``,
  16. otherwise from ``usps.t.bz2``.
  17. transform (callable, optional): A function/transform that takes in a PIL image
  18. and returns a transformed version. E.g, ``transforms.RandomCrop``
  19. target_transform (callable, optional): A function/transform that takes in the
  20. target and transforms it.
  21. download (bool, optional): If true, downloads the dataset from the internet and
  22. puts it in root directory. If dataset is already downloaded, it is not
  23. downloaded again.
  24. """
  25. split_list = {
  26. "train": [
  27. "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
  28. "usps.bz2",
  29. "ec16c51db3855ca6c91edd34d0e9b197",
  30. ],
  31. "test": [
  32. "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
  33. "usps.t.bz2",
  34. "8ea070ee2aca1ac39742fdd1ef5ed118",
  35. ],
  36. }
  37. def __init__(
  38. self,
  39. root: Union[str, Path],
  40. train: bool = True,
  41. transform: Optional[Callable] = None,
  42. target_transform: Optional[Callable] = None,
  43. download: bool = False,
  44. ) -> None:
  45. super().__init__(root, transform=transform, target_transform=target_transform)
  46. split = "train" if train else "test"
  47. url, filename, checksum = self.split_list[split]
  48. full_path = os.path.join(self.root, filename)
  49. if download and not os.path.exists(full_path):
  50. download_url(url, self.root, filename, md5=checksum)
  51. import bz2
  52. with bz2.open(full_path) as fp:
  53. raw_data = [line.decode().split() for line in fp.readlines()]
  54. tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
  55. imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
  56. imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
  57. targets = [int(d[0]) - 1 for d in raw_data]
  58. self.data = imgs
  59. self.targets = targets
  60. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  61. """
  62. Args:
  63. index (int): Index
  64. Returns:
  65. tuple: (image, target) where target is index of the target class.
  66. """
  67. img, target = self.data[index], int(self.targets[index])
  68. # doing this so that it is consistent with all other datasets
  69. # to return a PIL Image
  70. img = Image.fromarray(img, mode="L")
  71. if self.transform is not None:
  72. img = self.transform(img)
  73. if self.target_transform is not None:
  74. target = self.target_transform(target)
  75. return img, target
  76. def __len__(self) -> int:
  77. return len(self.data)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...