1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
|
- import os.path
- from pathlib import Path
- from typing import Any, Callable, List, Optional, Tuple, Union
- from PIL import Image
- from .vision import VisionDataset
class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
    which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level, so pycocotools is only
        # needed once a COCO dataset is actually instantiated.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorted image ids give a stable position -> image id mapping.
        # ``sorted`` already returns a new list, so no extra ``list()`` call.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id: int) -> Image.Image:
        """Open the image registered under ``id`` and return it as an RGB copy.

        The file path is ``self.root`` joined with the ``"file_name"`` entry of
        the image record returned by ``self.coco.loadImgs``.
        """
        path = self.coco.loadImgs(id)[0]["file_name"]
        # ``convert`` returns a separate copy, so the source image (and its
        # underlying file handle) can be closed deterministically on exit
        # instead of being left for the garbage collector.
        with Image.open(os.path.join(self.root, path)) as img:
            return img.convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        """Return the annotation records attached to the image ``id``."""
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair stored at position ``index``.

        Args:
            index (int): Position in ``self.ids``.

        Returns:
            tuple: ``(image, target)``; both are passed through
            ``self.transforms`` when it is not ``None``.

        Raises:
            ValueError: If ``index`` is not an ``int`` instance.
        """
        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        image_id = self.ids[index]
        image = self._load_image(image_id)
        target = self._load_target(image_id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of image ids in the dataset."""
        return len(self.ids)
class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
    which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.

    Each sample's target is the list of caption strings taken from the
    ``"caption"`` field of the image's annotation records.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        """Return only the caption text of every annotation of image ``id``."""
        # Fetch the full annotation records from the parent class, then keep
        # just their "caption" field.
        annotations = super()._load_target(id)
        return [record["caption"] for record in annotations]
|