Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

coco.py 4.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
  1. import os.path
  2. from pathlib import Path
  3. from typing import Any, Callable, List, Optional, Tuple, Union
  4. from PIL import Image
  5. from .vision import VisionDataset
  6. class CocoDetection(VisionDataset):
  7. """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
  8. It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
  9. which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
  10. Args:
  11. root (str or ``pathlib.Path``): Root directory where images are downloaded to.
  12. annFile (string): Path to json annotation file.
  13. transform (callable, optional): A function/transform that takes in a PIL image
  14. and returns a transformed version. E.g, ``transforms.PILToTensor``
  15. target_transform (callable, optional): A function/transform that takes in the
  16. target and transforms it.
  17. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  18. and returns a transformed version.
  19. """
  20. def __init__(
  21. self,
  22. root: Union[str, Path],
  23. annFile: str,
  24. transform: Optional[Callable] = None,
  25. target_transform: Optional[Callable] = None,
  26. transforms: Optional[Callable] = None,
  27. ) -> None:
  28. super().__init__(root, transforms, transform, target_transform)
  29. from pycocotools.coco import COCO
  30. self.coco = COCO(annFile)
  31. self.ids = list(sorted(self.coco.imgs.keys()))
  32. def _load_image(self, id: int) -> Image.Image:
  33. path = self.coco.loadImgs(id)[0]["file_name"]
  34. return Image.open(os.path.join(self.root, path)).convert("RGB")
  35. def _load_target(self, id: int) -> List[Any]:
  36. return self.coco.loadAnns(self.coco.getAnnIds(id))
  37. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  38. if not isinstance(index, int):
  39. raise ValueError(f"Index must be of type integer, got {type(index)} instead.")
  40. id = self.ids[index]
  41. image = self._load_image(id)
  42. target = self._load_target(id)
  43. if self.transforms is not None:
  44. image, target = self.transforms(image, target)
  45. return image, target
  46. def __len__(self) -> int:
  47. return len(self.ids)
  48. class CocoCaptions(CocoDetection):
  49. """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
  50. It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
  51. which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
  52. Args:
  53. root (str or ``pathlib.Path``): Root directory where images are downloaded to.
  54. annFile (string): Path to json annotation file.
  55. transform (callable, optional): A function/transform that takes in a PIL image
  56. and returns a transformed version. E.g, ``transforms.PILToTensor``
  57. target_transform (callable, optional): A function/transform that takes in the
  58. target and transforms it.
  59. transforms (callable, optional): A function/transform that takes input sample and its target as entry
  60. and returns a transformed version.
  61. Example:
  62. .. code:: python
  63. import torchvision.datasets as dset
  64. import torchvision.transforms as transforms
  65. cap = dset.CocoCaptions(root = 'dir where images are',
  66. annFile = 'json annotation file',
  67. transform=transforms.PILToTensor())
  68. print('Number of samples: ', len(cap))
  69. img, target = cap[3] # load 4th sample
  70. print("Image Size: ", img.size())
  71. print(target)
  72. Output: ::
  73. Number of samples: 82783
  74. Image Size: (3L, 427L, 640L)
  75. [u'A plane emitting smoke stream flying over a mountain.',
  76. u'A plane darts across a bright blue sky behind a mountain covered in snow',
  77. u'A plane leaves a contrail above the snowy mountain top.',
  78. u'A mountain that has a plane flying overheard in the distance.',
  79. u'A mountain view with a plume of smoke in the background']
  80. """
  81. def _load_target(self, id: int) -> List[str]:
  82. return [ann["caption"] for ann in super()._load_target(id)]
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...