Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

rendered_sst2.py 3.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
  1. from pathlib import Path
  2. from typing import Any, Callable, Optional, Tuple, Union
  3. from .folder import default_loader, make_dataset
  4. from .utils import download_and_extract_archive, verify_str_arg
  5. from .vision import VisionDataset
  class RenderedSST2(VisionDataset):
      """`The Rendered SST2 Dataset <https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md>`_.

      Rendered SST2 is an image classification dataset used to evaluate a model's capability on optical
      character recognition. This dataset was generated by rendering sentences in the Stanford Sentiment
      Treebank v2 dataset.

      This dataset contains two classes (positive and negative) and is divided in three splits: a train
      split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
      (444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).

      Args:
          root (str or ``pathlib.Path``): Root directory of the dataset.
          split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"`` and ``"test"``.
          transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor,
              depends on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop``
          target_transform (callable, optional): A function/transform that takes in the target and
              transforms it.
          download (bool, optional): If True, downloads the dataset from the internet and
              puts it in root directory. If dataset is already downloaded, it is not
              downloaded again. Default is False.
          loader (callable, optional): A function to load an image given its path.
              By default, it uses PIL as its image loader, but users could also pass in
              ``torchvision.io.decode_image`` for decoding image data into tensors directly.

      Raises:
          RuntimeError: If the class folders of the requested split are not found under
              ``root / "rendered-sst2"`` (after the optional download step).
      """

      _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
      _MD5 = "2384d08e9dcfa4bd55b324e610496ee5"

      def __init__(
          self,
          root: Union[str, Path],
          split: str = "train",
          transform: Optional[Callable] = None,
          target_transform: Optional[Callable] = None,
          download: bool = False,
          loader: Callable[[str], Any] = default_loader,
      ) -> None:
          super().__init__(root, transform=transform, target_transform=target_transform)
          self._split = verify_str_arg(split, "split", ("train", "val", "test"))
          # The public "val" split name maps to a folder called "valid" on disk.
          self._split_to_folder = {"train": "train", "val": "valid", "test": "test"}
          self._base_folder = Path(self.root) / "rendered-sst2"
          self.classes = ["negative", "positive"]
          self.class_to_idx = {"negative": 0, "positive": 1}

          if download:
              self._download()

          if not self._check_exists():
              raise RuntimeError("Dataset not found. You can use download=True to download it")

          # Each entry is unpacked as an (image_file, label) pair in ``__getitem__``.
          self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions=("png",))
          self.loader = loader

      def __len__(self) -> int:
          """Return the number of samples in the selected split."""
          return len(self._samples)

      def __getitem__(self, idx: int) -> Tuple[Any, Any]:
          """Load the sample at ``idx`` and return it as ``(image, label)``.

          The image is read with ``self.loader``; ``transform`` and ``target_transform``
          are applied to the image and label respectively when they were provided.
          """
          image_file, label = self._samples[idx]
          image = self.loader(image_file)

          # Compare against None rather than relying on truthiness: a callable that
          # defines ``__len__``/``__bool__`` and evaluates falsy must still be applied.
          if self.transform is not None:
              image = self.transform(image)

          if self.target_transform is not None:
              label = self.target_transform(label)

          return image, label

      def extra_repr(self) -> str:
          """Return the split description used as the extra part of the dataset repr."""
          return f"split={self._split}"

      def _check_exists(self) -> bool:
          """Return True when every class has a directory inside the selected split folder."""
          split_dir = self._base_folder / self._split_to_folder[self._split]
          return all((split_dir / class_label).is_dir() for class_label in self.classes)

      def _download(self) -> None:
          """Download and extract the archive into ``self.root`` unless the split is already present."""
          if self._check_exists():
              return
          download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...