Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dataset.py 1.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
  1. import torch
  2. import torchvision
  3. import transforms as T
  4. from coco_utils import preview_img_and_target
  5. torch.manual_seed(42)
  6. # Adapt to maskrcnn
  7. def get_transforms(train, preview=False):
  8. t = [
  9. T.FilterAndRemapCocoCategories(categories=[0, 3]),
  10. T.ConvertCocoPolysToMask(),
  11. T.ToTensor(),
  12. ]
  13. if train:
  14. t += [
  15. T.RandomHorizontalFlip(.5),
  16. T.RandomVerticalFlip(.5),
  17. T.ColorJitter((.8, 1.8), (.8, 1.8), .3, .1),
  18. ]
  19. if preview:
  20. t.append(T.PreviewTransforms())
  21. transforms = T.Compose(t)
  22. return transforms
  23. class CocoDetection(torchvision.datasets.CocoDetection):
  24. def __init__(self, img_folder, ann_file, transforms):
  25. super(CocoDetection, self).__init__(img_folder, ann_file)
  26. self._transforms = transforms
  27. def __getitem__(self, idx):
  28. img, target = super(CocoDetection, self).__getitem__(idx)
  29. image_id = self.ids[idx]
  30. target = dict(image_id=image_id, annotations=target)
  31. if self._transforms is not None:
  32. img, target = self._transforms(img, target)
  33. return img, target
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...