Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

stanford_cars.py 5.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
  1. import pathlib
  2. from typing import Any, Callable, Optional, Tuple, Union
  3. from .folder import default_loader
  4. from .utils import verify_str_arg
  5. from .vision import VisionDataset
  6. class StanfordCars(VisionDataset):
  7. """Stanford Cars Dataset
  8. The Cars dataset contains 16,185 images of 196 classes of cars. The data is
  9. split into 8,144 training images and 8,041 testing images, where each class
  10. has been split roughly in a 50-50 split
  11. The original URL is https://ai.stanford.edu/~jkrause/cars/car_dataset.html, but it is broken.
  12. Follow the instructions in ``download`` argument to obtain and use the dataset offline.
  13. .. note::
  14. This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
  15. Args:
  16. root (str or ``pathlib.Path``): Root directory of dataset
  17. split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
  18. transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
  19. and returns a transformed version. E.g, ``transforms.RandomCrop``
  20. target_transform (callable, optional): A function/transform that takes in the
  21. target and transforms it.
  22. download (bool, optional): This parameter exists for backward compatibility but it does not
  23. download the dataset, since the original URL is not available anymore. The dataset
  24. seems to be available on Kaggle so you can try to manually download and configure it using
  25. `these instructions <https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616>`_,
  26. or use an integrated
  27. `dataset on Kaggle <https://github.com/pytorch/vision/issues/7545#issuecomment-2282674373>`_.
  28. In both cases, first download and configure the dataset locally, and use the dataset with
  29. ``"download=False"``.
  30. loader (callable, optional): A function to load an image given its path.
  31. By default, it uses PIL as its image loader, but users could also pass in
  32. ``torchvision.io.decode_image`` for decoding image data into tensors directly.
  33. """
  34. def __init__(
  35. self,
  36. root: Union[str, pathlib.Path],
  37. split: str = "train",
  38. transform: Optional[Callable] = None,
  39. target_transform: Optional[Callable] = None,
  40. download: bool = False,
  41. loader: Callable[[str], Any] = default_loader,
  42. ) -> None:
  43. try:
  44. import scipy.io as sio
  45. except ImportError:
  46. raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
  47. super().__init__(root, transform=transform, target_transform=target_transform)
  48. self._split = verify_str_arg(split, "split", ("train", "test"))
  49. self._base_folder = pathlib.Path(root) / "stanford_cars"
  50. devkit = self._base_folder / "devkit"
  51. if self._split == "train":
  52. self._annotations_mat_path = devkit / "cars_train_annos.mat"
  53. self._images_base_path = self._base_folder / "cars_train"
  54. else:
  55. self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
  56. self._images_base_path = self._base_folder / "cars_test"
  57. if download:
  58. self.download()
  59. if not self._check_exists():
  60. raise RuntimeError(
  61. "Dataset not found. Try to manually download following the instructions in "
  62. "https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616."
  63. )
  64. self._samples = [
  65. (
  66. str(self._images_base_path / annotation["fname"]),
  67. annotation["class"] - 1, # Original target mapping starts from 1, hence -1
  68. )
  69. for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
  70. ]
  71. self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
  72. self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
  73. self.loader = loader
  74. def __len__(self) -> int:
  75. return len(self._samples)
  76. def __getitem__(self, idx: int) -> Tuple[Any, Any]:
  77. """Returns pil_image and class_id for given index"""
  78. image_path, target = self._samples[idx]
  79. image = self.loader(image_path)
  80. if self.transform is not None:
  81. image = self.transform(image)
  82. if self.target_transform is not None:
  83. target = self.target_transform(target)
  84. return image, target
  85. def _check_exists(self) -> bool:
  86. if not (self._base_folder / "devkit").is_dir():
  87. return False
  88. return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
  89. def download(self):
  90. raise ValueError(
  91. "The original URL is broken so the StanfordCars dataset is not available for automatic "
  92. "download anymore. You can try to download it manually following "
  93. "https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616, "
  94. "and set download=False to avoid this error."
  95. )
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...