_dataset_wrapper.py 24 KB

# type: ignore

from __future__ import annotations

import collections.abc

import contextlib

from collections import defaultdict
from copy import copy

import torch

from torchvision import datasets, tv_tensors
from torchvision.transforms.v2 import functional as F

__all__ = ["wrap_dataset_for_transforms_v2"]


def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
    """Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.

    Example:
        >>> dataset = torchvision.datasets.CocoDetection(...)
        >>> dataset = wrap_dataset_for_transforms_v2(dataset)

    .. note::

       For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
       configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
       to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.

    The dataset samples are wrapped according to the description below.

    Special cases:

        * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as a list of dicts, the wrapper
          returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
          ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.tv_tensors``.
          The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the
          ``"image_id"``, ``"boxes"``, and ``"labels"``.
        * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
          the target and wrap the data in the corresponding ``torchvision.tv_tensors``. The original keys are
          preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
        * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
          coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.
        * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as a list of dicts, the wrapper returns
          a dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
          in the corresponding ``torchvision.tv_tensors``. The original keys are preserved. If ``target_keys`` is
          omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
        * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
          :class:`~torchvision.tv_tensors.Mask` tv_tensor.
        * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
          :class:`~torchvision.tv_tensors.Mask` tv_tensor. The target for ``target_type="instance"`` is *replaced* by
          a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.tv_tensors.Mask` tv_tensor) and
          ``"labels"``.
        * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
          coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.

    Image classification datasets

        This wrapper is a no-op for image classification datasets, since they were already fully supported by
        :mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.

    Segmentation datasets

        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
        :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
        segmentation mask into a :class:`~torchvision.tv_tensors.Mask` (second item).

    Video classification datasets

        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
        :class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
        :class:`~torchvision.tv_tensors.Video` while leaving the other items as is.

        .. note::

            Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
            ``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.

    Args:
        dataset: the dataset instance to wrap for compatibility with transforms v2.
        target_keys: Target keys to return in case the target is a dictionary. If ``None`` (default), selected keys are
            specific to the dataset. If ``"all"``, returns the full target. Can also be a collection of strings for
            fine-grained access. Currently only supported for :class:`~torchvision.datasets.CocoDetection`,
            :class:`~torchvision.datasets.VOCDetection`, :class:`~torchvision.datasets.Kitti`, and
            :class:`~torchvision.datasets.WIDERFace`. See above for details.
    """
    if not (
        target_keys is None
        or target_keys == "all"
        or (isinstance(target_keys, collections.abc.Collection) and all(isinstance(key, str) for key in target_keys))
    ):
        raise ValueError(
            f"`target_keys` can be None, 'all', or a collection of strings denoting the keys to be returned, "
            f"but got {target_keys}"
        )

    # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
    # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetTVTensorWrapper (see below) as well as the
    # original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
    # while we can still inject everything that we need.
    wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetTVTensorWrapper, type(dataset)), {})
    # Since VisionDatasetTVTensorWrapper comes before ImageNet in the MRO, calling the class hits
    # VisionDatasetTVTensorWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
    # ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
    # have the existing instance as attribute on the new object.
    return wrapped_dataset_cls(dataset, target_keys)


class WrapperFactories(dict):
    def register(self, dataset_cls):
        def decorator(wrapper_factory):
            self[dataset_cls] = wrapper_factory
            return wrapper_factory

        return decorator


# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on the
# dataset instance rather than just the class, since they require the user defined instance attributes. Thus, we can
# provide a wrapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime when
# we have access to the dataset instance.
WRAPPER_FACTORIES = WrapperFactories()


class VisionDatasetTVTensorWrapper:
    def __init__(self, dataset, target_keys):
        dataset_cls = type(dataset)

        if not isinstance(dataset, datasets.VisionDataset):
            raise TypeError(
                f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
                f"but got a '{dataset_cls.__name__}' instead.\n"
                f"For an example of how to perform the wrapping for custom datasets, see\n\n"
                "https://pytorch.org/vision/main/auto_examples/plot_tv_tensors.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
            )

        for cls in dataset_cls.mro():
            if cls in WRAPPER_FACTORIES:
                wrapper_factory = WRAPPER_FACTORIES[cls]
                if target_keys is not None and cls not in {
                    datasets.CocoDetection,
                    datasets.VOCDetection,
                    datasets.Kitti,
                    datasets.WIDERFace,
                }:
                    raise ValueError(
                        f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
                        f"and `WIDERFace`, but got {cls.__name__}."
                    )
                break
            elif cls is datasets.VisionDataset:
                # TODO: If we have documentation on how to do that, put a link in the error message.
                msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
                if dataset_cls in datasets.__dict__.values():
                    msg = (
                        f"{msg} If an automated wrapper for this dataset would be useful for you, "
                        f"please open an issue at https://github.com/pytorch/vision/issues."
                    )
                raise TypeError(msg)

        self._dataset = dataset
        self._target_keys = target_keys
        self._wrapper = wrapper_factory(dataset, target_keys)

        # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
        # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
        # `transforms`
        # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
        # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
        # disable all three here to be able to extract the untransformed sample to wrap.
        self.transform, dataset.transform = dataset.transform, None
        self.target_transform, dataset.target_transform = dataset.target_transform, None
        self.transforms, dataset.transforms = dataset.transforms, None

    def __getattr__(self, item):
        with contextlib.suppress(AttributeError):
            return object.__getattribute__(self, item)

        return getattr(self._dataset, item)

    def __getitem__(self, idx):
        # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
        # of this class
        sample = self._dataset[idx]

        sample = self._wrapper(idx, sample)

        # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
        # or joint (`transforms`), we can access the full functionality through `transforms`
        if self.transforms is not None:
            sample = self.transforms(*sample)

        return sample

    def __len__(self):
        return len(self._dataset)

    # TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
    def __reduce__(self):
        # __reduce__ gets called when we try to pickle the dataset.
        # In a DataLoader with spawn context, this gets called `num_workers` times from the main process.

        # We have to reset the [target_]transform[s] attributes of the dataset
        # to their original values, because we previously set them to None in __init__().
        dataset = copy(self._dataset)
        dataset.transform = self.transform
        dataset.transforms = self.transforms
        dataset.target_transform = self.target_transform

        return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)


def raise_not_supported(description):
    raise RuntimeError(
        f"{description} is currently not supported by this wrapper. "
        f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues."
    )


def identity(item):
    return item


def identity_wrapper_factory(dataset, target_keys):
    def wrapper(idx, sample):
        return sample

    return wrapper


def pil_image_to_mask(pil_image):
    return tv_tensors.Mask(pil_image)


def parse_target_keys(target_keys, *, available, default):
    if target_keys is None:
        target_keys = default
    if target_keys == "all":
        target_keys = available
    else:
        target_keys = set(target_keys)
        extra = target_keys - available
        if extra:
            raise ValueError(f"Target keys {sorted(extra)} are not available")

    return target_keys


def list_of_dicts_to_dict_of_lists(list_of_dicts):
    dict_of_lists = defaultdict(list)
    for dct in list_of_dicts:
        for key, value in dct.items():
            dict_of_lists[key].append(value)
    return dict(dict_of_lists)


def wrap_target_by_type(target, *, target_types, type_wrappers):
    if not isinstance(target, (tuple, list)):
        target = [target]

    wrapped_target = tuple(
        type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target)
    )

    if len(wrapped_target) == 1:
        wrapped_target = wrapped_target[0]

    return wrapped_target


def classification_wrapper_factory(dataset, target_keys):
    return identity_wrapper_factory(dataset, target_keys)


for dataset_cls in [
    datasets.Caltech256,
    datasets.CIFAR10,
    datasets.CIFAR100,
    datasets.ImageNet,
    datasets.MNIST,
    datasets.FashionMNIST,
    datasets.GTSRB,
    datasets.DatasetFolder,
    datasets.ImageFolder,
    datasets.Imagenette,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)


def segmentation_wrapper_factory(dataset, target_keys):
    def wrapper(idx, sample):
        image, mask = sample
        return image, pil_image_to_mask(mask)

    return wrapper


for dataset_cls in [
    datasets.VOCSegmentation,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)


def video_classification_wrapper_factory(dataset, target_keys):
    if dataset.video_clips.output_format == "THWC":
        raise RuntimeError(
            f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
            f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
        )

    def wrapper(idx, sample):
        video, audio, label = sample

        video = tv_tensors.Video(video)

        return video, audio, label

    return wrapper


for dataset_cls in [
    datasets.HMDB51,
    datasets.Kinetics,
    datasets.UCF101,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)


@WRAPPER_FACTORIES.register(datasets.Caltech101)
def caltech101_wrapper_factory(dataset, target_keys):
    if "annotation" in dataset.target_type:
        raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")

    return classification_wrapper_factory(dataset, target_keys)


@WRAPPER_FACTORIES.register(datasets.CocoDetection)
def coco_dectection_wrapper_factory(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "segmentation",
            "area",
            "iscrowd",
            "image_id",
            "bbox",
            "category_id",
            # added by the wrapper
            "boxes",
            "masks",
            "labels",
        },
        default={"image_id", "boxes", "labels"},
    )

    def segmentation_to_mask(segmentation, *, canvas_size):
        from pycocotools import mask

        if isinstance(segmentation, dict):
            # if counts is a string, it is already an encoded RLE mask
            if not isinstance(segmentation["counts"], str):
                segmentation = mask.frPyObjects(segmentation, *canvas_size)
        elif isinstance(segmentation, list):
            segmentation = mask.merge(mask.frPyObjects(segmentation, *canvas_size))
        else:
            raise ValueError(f"COCO segmentation expected to be a dict or a list, got {type(segmentation)}")
        return torch.from_numpy(mask.decode(segmentation))

    def wrapper(idx, sample):
        image_id = dataset.ids[idx]

        image, target = sample

        if not target:
            return image, dict(image_id=image_id)

        canvas_size = tuple(F.get_size(image))

        batched_target = list_of_dicts_to_dict_of_lists(target)
        target = {}

        if "image_id" in target_keys:
            target["image_id"] = image_id

        if "boxes" in target_keys:
            target["boxes"] = F.convert_bounding_box_format(
                tv_tensors.BoundingBoxes(
                    batched_target["bbox"],
                    format=tv_tensors.BoundingBoxFormat.XYWH,
                    canvas_size=canvas_size,
                ),
                new_format=tv_tensors.BoundingBoxFormat.XYXY,
            )

        if "masks" in target_keys:
            target["masks"] = tv_tensors.Mask(
                torch.stack(
                    [
                        segmentation_to_mask(segmentation, canvas_size=canvas_size)
                        for segmentation in batched_target["segmentation"]
                    ]
                ),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor(batched_target["category_id"])

        for target_key in target_keys - {"image_id", "boxes", "masks", "labels"}:
            target[target_key] = batched_target[target_key]

        return image, target

    return wrapper


WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)


VOC_DETECTION_CATEGORIES = [
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES))))


@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "annotation",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])

        if "annotation" not in target_keys:
            target = {}

        if "boxes" in target_keys:
            target["boxes"] = tv_tensors.BoundingBoxes(
                [
                    [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
                    for bndbox in batched_instances["bndbox"]
                ],
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=(image.height, image.width),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor(
                [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
            )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.SBDataset)
def sbd_wrapper(dataset, target_keys):
    if dataset.mode == "boundaries":
        raise_not_supported("SBDataset with mode='boundaries'")

    return segmentation_wrapper_factory(dataset, target_keys)


@WRAPPER_FACTORIES.register(datasets.CelebA)
def celeba_wrapper_factory(dataset, target_keys):
    if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
        raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "bbox": lambda item: F.convert_bounding_box_format(
                    tv_tensors.BoundingBoxes(
                        item,
                        format=tv_tensors.BoundingBoxFormat.XYWH,
                        canvas_size=(image.height, image.width),
                    ),
                    new_format=tv_tensors.BoundingBoxFormat.XYXY,
                ),
            },
        )

        return image, target

    return wrapper


KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"]
KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES))))


@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "type",
            "truncated",
            "occluded",
            "alpha",
            "bbox",
            "dimensions",
            "location",
            "rotation_y",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        if target is None:
            return image, target

        batched_target = list_of_dicts_to_dict_of_lists(target)
        target = {}

        if "boxes" in target_keys:
            target["boxes"] = tv_tensors.BoundingBoxes(
                batched_target["bbox"],
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=(image.height, image.width),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in batched_target["type"]])

        for target_key in target_keys - {"boxes", "labels"}:
            target[target_key] = batched_target[target_key]

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factor(dataset, target_keys):
    def wrapper(idx, sample):
        image, target = sample

        if target is not None:
            target = wrap_target_by_type(
                target,
                target_types=dataset._target_types,
                type_wrappers={
                    "segmentation": pil_image_to_mask,
                },
            )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.Cityscapes)
def cityscapes_wrapper_factory(dataset, target_keys):
    if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
        raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")

    def instance_segmentation_wrapper(mask):
        # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
        data = pil_image_to_mask(mask)
        masks = []
        labels = []
        for id in data.unique():
            masks.append(data == id)
            label = id
            if label >= 1_000:
                label //= 1_000
            labels.append(label)
        return dict(masks=tv_tensors.Mask(torch.stack(masks)), labels=torch.stack(labels))

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "instance": instance_segmentation_wrapper,
                "semantic": pil_image_to_mask,
            },
        )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            "bbox",
            "blur",
            "expression",
            "illumination",
            "occlusion",
            "pose",
            "invalid",
        },
        default="all",
    )

    def wrapper(idx, sample):
        image, target = sample

        if target is None:
            return image, target

        target = {key: target[key] for key in target_keys}

        if "bbox" in target_keys:
            target["bbox"] = F.convert_bounding_box_format(
                tv_tensors.BoundingBoxes(
                    target["bbox"], format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
                ),
                new_format=tv_tensors.BoundingBoxFormat.XYXY,
            )

        return image, target

    return wrapper
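
The wrapping pattern used in this file (a class-to-factory registry, a subclass created at runtime with `type(...)`, and transforms moved from the wrapped dataset onto the wrapper) can be tried without `torch` or `torchvision`. The following is a minimal sketch under stated assumptions, not the library's implementation: `Registry`, `BaseDataset`, `ToyDetection`, `Wrapper`, `toy_detection_factory` and `wrap` are hypothetical names invented for illustration, and the factory only performs the list-of-dicts to dict-of-lists conversion.

```python
from collections import defaultdict


class Registry(dict):
    """Hypothetical stand-in for WrapperFactories: maps a dataset class to a factory."""

    def register(self, dataset_cls):
        def decorator(factory):
            self[dataset_cls] = factory
            return factory

        return decorator


FACTORIES = Registry()


class BaseDataset:
    """Hypothetical stand-in for a VisionDataset holding (image, target) samples."""

    def __init__(self, samples, transforms=None):
        self.samples = samples
        self.transforms = transforms

    def __getitem__(self, idx):
        sample = self.samples[idx]
        if self.transforms is not None:
            sample = self.transforms(*sample)
        return sample

    def __len__(self):
        return len(self.samples)


class ToyDetection(BaseDataset):
    """Toy detection dataset whose target is a list of dicts."""


class Wrapper:
    def __init__(self, dataset):
        # Walk the MRO so that subclasses of a registered dataset are also covered.
        for cls in type(dataset).mro():
            if cls in FACTORIES:
                factory = FACTORIES[cls]
                break
        else:
            raise TypeError(f"No wrapper exists for {type(dataset).__name__}")

        self._dataset = dataset
        # Second stage: the factory sees the instance, not just the class.
        self._wrapper = factory(dataset)
        # Move the transforms to the wrapper so the raw sample can be wrapped first.
        self.transforms, dataset.transforms = dataset.transforms, None

    def __getattr__(self, item):
        # Only reached when normal lookup fails; delegate to the wrapped dataset.
        return getattr(object.__getattribute__(self, "_dataset"), item)

    def __getitem__(self, idx):
        sample = self._wrapper(idx, self._dataset[idx])
        if self.transforms is not None:
            sample = self.transforms(*sample)
        return sample

    def __len__(self):
        return len(self._dataset)


@FACTORIES.register(ToyDetection)
def toy_detection_factory(dataset):
    def wrapper(idx, sample):
        image, target = sample
        batched = defaultdict(list)
        for dct in target:
            for key, value in dct.items():
                batched[key].append(value)
        return image, dict(batched)

    return wrapper


def wrap(dataset):
    # Wrapper comes first in the MRO, so the original constructor is never re-run.
    wrapped_cls = type(f"Wrapped{type(dataset).__name__}", (Wrapper, type(dataset)), {})
    return wrapped_cls(dataset)


dataset = ToyDetection(
    [("img0", [{"bbox": [0, 0, 2, 2], "label": 1}, {"bbox": [1, 1, 3, 3], "label": 2}])],
    transforms=lambda image, target: (image.upper(), target),
)
wrapped = wrap(dataset)
image, target = wrapped[0]
```

In this sketch the user-supplied `transforms` callable receives the already converted dict-of-lists target, and `isinstance(wrapped, ToyDetection)` still holds because the runtime class inherits from both `Wrapper` and the original dataset class.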