Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

data.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
  1. import keras
  2. import numpy as np
  3. import os
  4. import tensorflow as tf
  5. import tensorflow_hub as hub
  6. from PIL import Image
  7. class EfficientNetFeatureExtractor:
  8. def __init__(self):
  9. self.model = hub.KerasLayer("https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_m/feature_vector/2", trainable=False)
  10. self.preprocess = tf.keras.applications.efficientnet.preprocess_input
  11. self.feature_dims = 1280
  12. def extract(self, image_path):
  13. image = Image.open(image_path).convert('RGB')
  14. image = image.resize((480, 480)) # Resize the image to match EfficientNet's input size
  15. image = tf.keras.preprocessing.image.img_to_array(image)
  16. image = self.preprocess(image)
  17. image = tf.expand_dims(image, axis=0)
  18. embedding = self.model(image)[0]
  19. return embedding
  20. class LAIONAestheticsDataGenerator(keras.utils.Sequence):
  21. def __init__(self, img_files, scores, img_dir, feature_extractor: EfficientNetFeatureExtractor, batch_size=32, shuffle=True):
  22. self.feature_extractor = feature_extractor
  23. self.img_path = img_dir
  24. self.img_files = img_files
  25. self.scores = scores
  26. self.embeddings = {}
  27. self.batch_size = batch_size
  28. self.shuffle = shuffle
  29. self.on_epoch_end()
  30. def __len__(self):
  31. return len(self.img_files)
  32. def __getitem__(self, idx):
  33. if tf.is_tensor(idx):
  34. idx = idx.numpy()
  35. start = idx * self.batch_size
  36. stop = start + self.batch_size
  37. batch_idxs = self.indexes[start:stop]
  38. embeddings, scores = self.__data_generation(batch_idxs)
  39. return embeddings, scores
  40. def __data_generation(self, idxs):
  41. embeddings = np.empty((self.batch_size, self.feature_extractor.feature_dims))
  42. scores = np.empty((self.batch_size))
  43. # Generate data
  44. for i, idx in enumerate(idxs):
  45. # Store sample
  46. embedding = self.embeddings.get(idx, None)
  47. if embedding is None:
  48. img_path = os.path.join(self.img_path, self.img_files[idx])
  49. embedding = self.feature_extractor.extract(img_path)
  50. self.embeddings[idx] = embedding
  51. embeddings[i,] = embedding
  52. scores[i] = self.scores[idx]
  53. return embeddings, scores
  54. def on_epoch_end(self):
  55. 'Updates indexes after each epoch'
  56. self.indexes = np.arange(len(self.img_files))
  57. if self.shuffle == True:
  58. np.random.shuffle(self.indexes)
  59. def train_valid_split(data_dir, train_percent=0.8, limit=None, batch_size=32):
  60. annotations_file = os.path.join(data_dir, 'labels.tsv')
  61. data = []
  62. with open(annotations_file) as f:
  63. for i, row in enumerate(f.readlines()):
  64. if limit is not None and i >= limit:
  65. break
  66. img_name, _, aesthetic_score = row.split('\t')[:3]
  67. data.append((img_name, tf.constant(float(aesthetic_score))))
  68. np.random.shuffle(data)
  69. feature_extractor = EfficientNetFeatureExtractor()
  70. train_size = int(train_percent * len(data))
  71. train_data = data[:train_size]
  72. valid_data = data[train_size:]
  73. train_imgs, train_scores = zip(*train_data)
  74. valid_imgs, valid_scores = zip(*valid_data)
  75. train_generator = LAIONAestheticsDataGenerator(train_imgs,
  76. train_scores,
  77. data_dir,
  78. feature_extractor,
  79. batch_size,
  80. shuffle=True)
  81. valid_generator = LAIONAestheticsDataGenerator(valid_imgs,
  82. valid_scores,
  83. data_dir,
  84. feature_extractor,
  85. batch_size,
  86. shuffle=False)
  87. return train_generator, valid_generator
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...