Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

data.py 2.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
  1. import os
  2. import torch
  3. import torchvision.models as models
  4. from PIL import Image
  5. from torch.utils.data import Dataset
  6. class EfficientNetFeatureExtractor:
  7. def __init__(self):
  8. model = models.efficientnet_v2_m(weights=models.EfficientNet_V2_M_Weights.IMAGENET1K_V1)
  9. model = torch.nn.Sequential(*list(model.children())[:-1])
  10. self.model = model
  11. self.preprocess = models.EfficientNet_V2_M_Weights.IMAGENET1K_V1.transforms()
  12. self.feature_dims = 1280
  13. def extract(self, image_path):
  14. with torch.no_grad():
  15. image = Image.open(image_path).convert('RGB')
  16. X = self.preprocess(image)
  17. X = X.unsqueeze(0)
  18. embedding = self.model(X)[0, :, 0, 0]
  19. return embedding
  20. class LAIONAestheticsDataset(Dataset):
  21. def __init__(self, annotations_file, img_dir, feature_extractor: EfficientNetFeatureExtractor, limit=None):
  22. self.feature_extractor = feature_extractor
  23. self.img_path = img_dir
  24. self.img_files = []
  25. self.scores = []
  26. self.embeddings = {}
  27. with open(annotations_file) as f:
  28. for i, row in enumerate(f.readlines()):
  29. if limit is not None and i >= limit:
  30. break
  31. img_name, _, aesthetic_score = row.split('\t')[:3]
  32. self.img_files.append(img_name)
  33. self.scores.append(torch.tensor([float(aesthetic_score)]))
  34. def __len__(self):
  35. return len(self.img_files)
  36. def __getitem__(self, idx):
  37. if torch.is_tensor(idx):
  38. idx = idx.tolist()
  39. score = self.scores[idx]
  40. embedding = self.embeddings.get(idx, None)
  41. if embedding is None:
  42. img_path = os.path.join(self.img_path, self.img_files[idx])
  43. embedding = self.feature_extractor.extract(img_path)
  44. self.embeddings[idx] = embedding
  45. return embedding, score
  46. def train_valid_split(data_dir, train_percent=0.8, limit=None):
  47. feature_extractor = EfficientNetFeatureExtractor()
  48. annotations_file = os.path.join(data_dir, 'labels.tsv')
  49. dataset_train_valid = LAIONAestheticsDataset(annotations_file, data_dir, feature_extractor, limit=limit)
  50. # Split into train/valid
  51. train_size = int(train_percent * len(dataset_train_valid))
  52. valid_size = len(dataset_train_valid) - train_size
  53. train_dataset, valid_dataset = torch.utils.data.random_split(dataset_train_valid, [train_size, valid_size])
  54. return train_dataset, valid_dataset
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...