Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dedup.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
  1. import argparse
  2. import glob
  3. import json
  4. import numpy as np
  5. import os
  6. import sys
  7. import torch
  8. from PIL import Image
  9. from sklearn.metrics.pairwise import cosine_similarity
  10. from torchvision import transforms
  11. from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
  12. from torchvision.models.feature_extraction import create_feature_extractor
  13. class FeatureExtractor(torch.nn.Module):
  14. def __init__(self):
  15. super(FeatureExtractor, self).__init__()
  16. model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
  17. return_nodes = {'flatten': 'feature_vector'}
  18. self._feature_extractor = create_feature_extractor(model, return_nodes)
  19. self._feature_extractor.eval()
  20. self._preprocess = transforms.Compose([
  21. transforms.Resize(224),
  22. transforms.ToTensor(),
  23. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  24. ])
  25. def forward(self, img):
  26. data = Image.open(img)
  27. data = data.convert('RGB')
  28. x = self._preprocess(data).unsqueeze(0)
  29. with torch.no_grad():
  30. y = self._feature_extractor(x)['feature_vector']
  31. return y.squeeze().numpy()
  32. def is_image(filename):
  33. ext = os.path.splitext(filename)[-1]
  34. return ext.lower() in ('.jpg', '.jpeg', '.png')
  35. def extract_features(images):
  36. extractor = FeatureExtractor()
  37. features = []
  38. for image in images:
  39. features.append(extractor(image))
  40. return np.array(features)
  41. def main():
  42. parser = argparse.ArgumentParser('Creates a list of duplicate images found within a directory and saves it to a file (similarity.json)')
  43. parser.add_argument('input', help='Input directory with images to compare')
  44. parser.add_argument('--threshold', type=float, default=0.95, help='Threshold to determine similarity')
  45. args = parser.parse_args()
  46. threshold = args.threshold
  47. images = [f for f in glob.glob(os.path.join(args.input, '*')) if is_image(f)]
  48. img_names = np.array([os.path.split(i)[-1] for i in images])
  49. print(f'Anylizing {len(images)} images...')
  50. print(f' + Extracting features...')
  51. features = extract_features(images)
  52. print(f' + Calculating cosine simularity...')
  53. scores = cosine_similarity(features)
  54. np.fill_diagonal(scores, 0.0)
  55. print(f' + Computing output')
  56. similarity_set = set()
  57. for i, j in enumerate(scores):
  58. pred = (j >= threshold)
  59. same = img_names[pred].tolist()
  60. if len(same):
  61. same.append(img_names[i])
  62. same.sort()
  63. similarity_set.add(';'.join(same))
  64. similarity_list = [s.split(';') for s in sorted(list(similarity_set))]
  65. with open('similarity.json', mode='w') as f:
  66. json.dump(similarity_list, f)
  67. if __name__ == '__main__':
  68. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...