1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
- import os
- from pathlib import Path
- from typing import Any, Callable, Optional, Tuple, Union
- import numpy as np
- from PIL import Image
- from .utils import download_url
- from .vision import VisionDataset
class USPS(VisionDataset):
    """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.

    Each line of the (bz2-compressed) data file has the layout
    ``label index:value index:value ...`` with 256 ``index:value`` pairs, i.e.
    [label [index:value ]*256 \\n] * num_lines. The raw ``label`` lies in
    ``[1, 10]`` and every pixel value lies in ``[-1, 1]``. On load the labels
    are shifted to ``[0, 9]`` and the pixels are rescaled to ``[0, 255]`` and
    stored as 16x16 ``uint8`` images.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset to store ``USPS`` data files.
        train (bool, optional): If True, creates dataset from ``usps.bz2``,
            otherwise from ``usps.t.bz2``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    # Per split: [download URL, file name inside ``root``, MD5 checksum].
    split_list = {
        "train": [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
            "usps.bz2",
            "ec16c51db3855ca6c91edd34d0e9b197",
        ],
        "test": [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
            "usps.t.bz2",
            "8ea070ee2aca1ac39742fdd1ef5ed118",
        ],
    }

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """Load the requested split into memory, downloading it first if asked to.

        After construction ``self.data`` holds the images as a ``uint8`` array
        reshaped to ``(-1, 16, 16)`` and ``self.targets`` holds the zero-based
        labels as a list of ints.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)

        url, filename, checksum = self.split_list["train" if train else "test"]
        archive_path = os.path.join(self.root, filename)

        # Only hit the network when explicitly requested and the file is absent.
        if download and not os.path.exists(archive_path):
            download_url(url, self.root, filename, md5=checksum)

        import bz2

        # Read every line up front; each record becomes a list of
        # whitespace-separated tokens: [label, "idx:val", "idx:val", ...].
        with bz2.open(archive_path) as archive:
            records = [raw_line.decode().split() for raw_line in archive.readlines()]

        # Keep only the value part of every "index:value" token.
        pixel_strings = [[token.split(":")[-1] for token in record[1:]] for record in records]
        pixels = np.asarray(pixel_strings, dtype=np.float32).reshape((-1, 16, 16))

        # Map the pixel range [-1, 1] onto [0, 255]; the uint8 cast truncates.
        unit_range = (pixels + 1) / 2
        pixels = (unit_range * 255).astype(np.uint8)

        # Labels in the file are one-based; shift them to start at zero.
        labels = [int(record[0]) - 1 for record in records]

        self.data = pixels
        self.targets = labels

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        pixels = self.data[index]
        label = int(self.targets[index])

        # Hand back a PIL Image so this dataset behaves like the other
        # image datasets.
        picture = Image.fromarray(pixels, mode="L")

        if self.transform is not None:
            picture = self.transform(picture)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return picture, label

    def __len__(self) -> int:
        """Return the number of samples held in ``self.data``."""
        return len(self.data)
|