```python
import os

import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ToTensor
from PIL import Image

from minimagen.t5 import t5_encode_text
from minimagen.training import _Rescale


class LAIONAestheticsDataset(Dataset):
    """LAION Aesthetics images paired with T5-encoded captions."""

    def __init__(self, annotations_file, img_dir, encoder_name: str, max_length: int,
                 side_length: int, img_transform=None, limit=None):
        self.img_path = img_dir
        self.img_files = []
        self.captions = []

        # Each row of the annotations file is "<image filename>\t<caption>..."
        with open(annotations_file) as f:
            for i, row in enumerate(f.readlines()):
                if limit is not None and i >= limit:
                    break
                img_name, caption = row.rstrip('\n').split('\t')[:2]
                self.img_files.append(img_name)
                self.captions.append(caption)

        # Always convert to a tensor and rescale to a `side_length` square,
        # then apply any user-supplied transform on top.
        if img_transform is None:
            self.img_transform = Compose([ToTensor(), _Rescale(side_length)])
        else:
            self.img_transform = Compose([ToTensor(), _Rescale(side_length), img_transform])

        self.encoder_name = encoder_name
        self.max_length = max_length

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img = Image.open(os.path.join(self.img_path, self.img_files[idx]))
        if self.img_transform:
            img = self.img_transform(img)

        # `_Rescale` returns None for images it cannot resize, and non-RGB
        # images (grayscale, RGBA) don't have 3 channels - return None in
        # both cases so the collate function can drop these samples.
        if img is None or img.shape[0] != 3:
            return None

        enc, msk = t5_encode_text([self.captions[idx]], self.encoder_name, self.max_length)
        return {'image': img, 'encoding': enc, 'mask': msk}


def train_valid_split(args, smalldata=False):
    # Cap the dataset at 16 examples for quick smoke tests
    limit = 16 if smalldata else None

    annotations_file = os.path.join(args.INPUT_DIRECTORY, 'labels.tsv')
    dataset_train_valid = LAIONAestheticsDataset(annotations_file, args.INPUT_DIRECTORY,
                                                 max_length=args.MAX_NUM_WORDS,
                                                 encoder_name=args.T5_NAME,
                                                 side_length=args.IMG_SIDE_LEN,
                                                 limit=limit)

    # Split into train/valid
    train_size = int(args.TRAIN_VALID_FRAC * len(dataset_train_valid))
    valid_size = len(dataset_train_valid) - train_size
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset_train_valid,
                                                                 [train_size, valid_size])

    # Optionally cap the validation set at VALID_NUM examples
    if args.VALID_NUM is not None:
        valid_dataset.indices = valid_dataset.indices[:args.VALID_NUM]

    return train_dataset, valid_dataset
```
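Because `__getitem__` returns `None` for images it rejects, a `DataLoader` built on these datasets needs a collate function that filters those entries out before batching. Here is a minimal sketch of that wiring; `collate_drop_none` and the `SimpleNamespace` field values are illustrative assumptions for this example (in the real script, `args` comes from the training configuration), not part of the original code:

```python
from types import SimpleNamespace

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


def collate_drop_none(batch):
    # Drop samples that __getitem__ rejected (unresizable or non-RGB images)
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None  # the training loop should skip an entirely-empty batch
    return default_collate(batch)


# Illustrative argument values - adjust to your data and hardware
args = SimpleNamespace(
    INPUT_DIRECTORY='./laion_aesthetics',
    MAX_NUM_WORDS=64,
    T5_NAME='t5_base',
    IMG_SIDE_LEN=64,
    TRAIN_VALID_FRAC=0.9,
    VALID_NUM=None,
)

train_dataset, valid_dataset = train_valid_split(args, smalldata=True)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                          collate_fn=collate_drop_none)
```

Note that `t5_encode_text` is called on a single-element list per sample, so the `encoding` and `mask` tensors carry a singleton batch dimension; depending on how the training loop consumes batches, you may want to squeeze that dimension either in `__getitem__` or in the collate function.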