Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

fer2013.py 5.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
  1. import csv
  2. import pathlib
  3. from typing import Any, Callable, Optional, Tuple, Union
  4. import torch
  5. from PIL import Image
  6. from .utils import check_integrity, verify_str_arg
  7. from .vision import VisionDataset
  class FER2013(VisionDataset):
      """`FER2013
      <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

      .. note::
          Test labels are only available when ``fer2013.csv`` or
          ``icml_face_data.csv`` is present in ``root/fer2013/``. When only
          ``train.csv`` and ``test.csv`` are present, every test target is
          ``None``.

      Args:
          root (str or ``pathlib.Path``): Root directory of dataset where directory
              ``root/fer2013`` exists. This directory may contain either
              ``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
              ``test.csv``. Precedence is given in that order, i.e. if
              ``fer2013.csv`` is present then the rest of the files will be
              ignored. All these (combinations of) files contain the same data and
              are supported for convenience, but only ``fer2013.csv`` and
              ``icml_face_data.csv`` are able to return non-None test labels.
          split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
          transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
              version. E.g, ``transforms.RandomCrop``
          target_transform (callable, optional): A function/transform that takes in the target and transforms it.
      """

      # Maps a resource key to ``(file name, expected MD5)``.
      #
      # ``fer2013.csv`` and ``icml_face_data.csv`` each hold the train AND the
      # test instances and, unlike ``test.csv``, also carry the test labels, so
      # they take precedence over ``train.csv``/``test.csv``. Both hold the same
      # data; only the header differs (column order, and leading spaces in the
      # icml column names):
      #
      #   fer2013.csv        -> emotion,pixels,Usage
      #   icml_face_data.csv -> emotion, Usage, pixels
      #   train.csv          -> emotion,pixels
      #   test.csv           -> pixels
      _RESOURCES = {
          "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
          "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
          "fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
          "icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
      }

      def __init__(
          self,
          root: Union[str, pathlib.Path],
          split: str = "train",
          transform: Optional[Callable] = None,
          target_transform: Optional[Callable] = None,
      ) -> None:
          """Locate the CSV file for ``split``, verify it, and load all samples into memory.

          Raises:
              RuntimeError: If the selected CSV file fails ``check_integrity``
                  against its expected MD5.
          """
          self._split = verify_str_arg(split, "split", ("train", "test"))
          super().__init__(root, transform=transform, target_transform=target_transform)

          dataset_dir = pathlib.Path(self.root) / "fer2013"

          # Pick the source file: fer2013.csv first, then icml_face_data.csv,
          # and only then the per-split train.csv / test.csv.
          has_fer = (dataset_dir / self._RESOURCES["fer"][0]).exists()
          has_icml = not has_fer and (dataset_dir / self._RESOURCES["icml"][0]).exists()
          if has_fer:
              resource_key = "fer"
          elif has_icml:
              resource_key = "icml"
          else:
              resource_key = self._split
          file_name, md5 = self._RESOURCES[resource_key]

          data_file = dataset_dir / file_name
          if not check_integrity(str(data_file), md5=md5):
              raise RuntimeError(
                  f"{file_name} not found in {dataset_dir} or corrupted. "
                  f"You can download it from "
                  f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
              )

          # The icml header has a space after each comma, so its column names
          # carry a leading space.
          if has_icml:
              pixels_key, usage_key = " pixels", " Usage"
          else:
              pixels_key, usage_key = "pixels", "Usage"

          # Combined files hold both splits and are filtered on the usage column.
          is_combined = has_fer or has_icml
          wanted_usages = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest")
          # Only the standalone test.csv has no "emotion" column.
          has_labels = is_combined or self._split == "train"

          samples = []
          with open(data_file, "r", newline="") as handle:
              for record in csv.DictReader(handle):
                  if is_combined and record[usage_key] not in wanted_usages:
                      continue
                  # Pixels are whitespace-separated integers for a 48x48 image.
                  pixel_values = [int(value) for value in record[pixels_key].split()]
                  image = torch.tensor(pixel_values, dtype=torch.uint8).reshape(48, 48)
                  label = int(record["emotion"]) if has_labels else None
                  samples.append((image, label))
          self._samples = samples

      def __len__(self) -> int:
          """Return the number of samples loaded for the selected split."""
          return len(self._samples)

      def __getitem__(self, idx: int) -> Tuple[Any, Any]:
          """Return ``(image, target)`` for sample ``idx`` with the optional transforms applied.

          The stored tensor is converted to a PIL image before ``transform``
          runs; ``target`` is ``None`` when the source file has no labels.
          """
          pixels, label = self._samples[idx]
          picture = Image.fromarray(pixels.numpy())

          if self.transform is not None:
              picture = self.transform(picture)

          if self.target_transform is not None:
              label = self.target_transform(label)

          return picture, label

      def extra_repr(self) -> str:
          """Return the split description string ``split=<split>``."""
          return f"split={self._split}"
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...