Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

data.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
  1. import os
  2. import torch
  3. from minimagen.t5 import t5_encode_text
  4. from minimagen.training import _Rescale
  5. from torch.utils.data import Dataset
  6. from torchvision.transforms import Compose, ToTensor
  7. from PIL import Image
  8. class LAIONAestheticsDataset(Dataset):
  9. def __init__(self, annotations_file, img_dir, encoder_name: str, max_length: int,
  10. side_length: int, img_transform=None, limit=None):
  11. self.img_path = img_dir
  12. self.img_files = []
  13. self.captions = []
  14. with open(annotations_file) as f:
  15. for i, row in enumerate(f.readlines()):
  16. if limit is not None and i >= limit:
  17. break
  18. img_name, caption = row.split('\t')[:2]
  19. self.img_files.append(img_name)
  20. self.captions.append(caption)
  21. if img_transform is None:
  22. self.img_transform = Compose([ToTensor(), _Rescale(side_length)])
  23. else:
  24. self.img_transform = Compose([ToTensor(), _Rescale(side_length), img_transform])
  25. self.encoder_name = encoder_name
  26. self.max_length = max_length
  27. def __len__(self):
  28. return len(self.img_files)
  29. def __getitem__(self, idx):
  30. if torch.is_tensor(idx):
  31. idx = idx.tolist()
  32. img = Image.open(os.path.join(self.img_path, self.img_files[idx]))
  33. if img is None:
  34. return None
  35. elif self.img_transform:
  36. img = self.img_transform(img)
  37. # Have to check None again because `Resize` transform can return None
  38. if img is None:
  39. return None
  40. elif img.shape[0] != 3:
  41. return None
  42. enc, msk = t5_encode_text([self.captions[idx]], self.encoder_name, self.max_length)
  43. return {'image': img, 'encoding': enc, 'mask': msk}
  44. def train_valid_split(args, smalldata=False):
  45. limit = smalldata and 16 or None
  46. annotations_file = os.path.join(args.INPUT_DIRECTORY, 'labels.tsv')
  47. dataset_train_valid = LAIONAestheticsDataset(annotations_file, args.INPUT_DIRECTORY, max_length=args.MAX_NUM_WORDS, encoder_name=args.T5_NAME,
  48. side_length=args.IMG_SIDE_LEN, limit=limit)
  49. # Split into train/valid
  50. train_size = int(args.TRAIN_VALID_FRAC * len(dataset_train_valid))
  51. valid_size = len(dataset_train_valid) - train_size
  52. train_dataset, valid_dataset = torch.utils.data.random_split(dataset_train_valid, [train_size, valid_size])
  53. if args.VALID_NUM is not None:
  54. valid_dataset.indices = valid_dataset.indices[:args.VALID_NUM + 1]
  55. return train_dataset, valid_dataset
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...