Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

data.py 4.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
  1. import hashlib
  2. import json
  3. import os
  4. import pandas as pd
  5. import shutil
  6. import yaml
  7. from sklearn.model_selection import train_test_split
  8. from .coco_converter import COCOAnnotationConverter
  9. from .yolo_converter import YOLOAnnotationConverter
  def _split_off(frame, fraction, random_state):
      """Split ``frame`` into ``(kept, split_off)`` with ``fraction`` of rows split off.

      Empty frames and degenerate fractions (<= 0 or >= 1) are resolved here
      without calling ``train_test_split``, which rejects them with ValueError.
      """
      if frame.empty or fraction <= 0:
          return frame, frame.iloc[0:0]
      if fraction >= 1:
          return frame.iloc[0:0], frame
      return train_test_split(frame, test_size=fraction, random_state=random_state)


  def create_splits(df, train=0.6, valid=0.3, test=0.1, *, random_state=None):
      """Randomly partition ``df`` into train/valid/test rows tagged by a ``split`` column.

      The three fractions are relative weights; they are normalised by their sum.
      Any of them may be zero, in which case that split is simply empty.

      Args:
          df: DataFrame to partition. It is not modified.
          train, valid, test: Non-negative relative sizes of the three splits.
          random_state: Optional seed or RNG forwarded to ``train_test_split`` so
              the partition can be reproduced.

      Returns:
          A new DataFrame (index reset) with train rows first, then valid, then
          test, and a ``split`` column holding 'train' / 'valid' / 'test'.

      Raises:
          ValueError: If a fraction is negative or all fractions are zero.
      """
      if min(train, valid, test) < 0:
          raise ValueError('split fractions must be non-negative')
      total = train + valid + test
      if total <= 0:
          raise ValueError('at least one split fraction must be positive')
      rest = valid + test
      # sklearn computes the held-out size as ceil(test_size * n). The normalisation
      # leaves float noise (e.g. 0.4000000000000001 for the defaults), which
      # can add a spurious extra row, so round it away first.
      rest_share = round(rest / total, 12)
      test_share = round(test / rest, 12) if rest > 0 else 0.0
      train_df, rest_df = _split_off(df, rest_share, random_state)
      valid_df, test_df = _split_off(rest_df, test_share, random_state)
      # ``assign`` returns new frames, so slices handed back by sklearn are never
      # written to in place.
      return pd.concat(
          [
              train_df.assign(split='train'),
              valid_df.assign(split='valid'),
              test_df.assign(split='test'),
          ],
          ignore_index=True,
      )
  def create_deterministic_splits(df, train=20, valid=10, test=20):
      """Assign fixed-size train/valid/test splits in an order derived from path hashes.

      Rows are ordered by the MD5 digest of their ``path`` value. The assignment
      therefore depends on the paths themselves, not on the input row order. The first
      ``train`` rows of that order become 'train', the next ``valid`` become
      'valid', the next ``test`` become 'test'. Any remaining rows are dropped.

      Args:
          df: DataFrame with a string ``path`` column. It is not modified.
          train, valid, test: Non-negative row counts for each split.

      Returns:
          A new DataFrame with index 0..k-1 and a ``split`` column. All other
          columns of ``df``, including one named ``hash``, are preserved.

      Raises:
          ValueError: If a count is negative.
      """
      if min(train, valid, test) < 0:
          raise ValueError('split sizes must be non-negative')
      # Work on a positional copy so the caller's frame is never mutated
      # (the previous version left a stray 'hash' column on it).
      frame = df.reset_index(drop=True)
      digests = [hashlib.md5(path.encode()).digest() for path in frame['path']]
      # ``sorted`` is stable, so rows with identical paths keep their input order.
      order = sorted(range(len(digests)), key=digests.__getitem__)
      labels = ['train'] * train + ['valid'] * valid + ['test'] * test
      result = frame.iloc[order[:len(labels)]].reset_index(drop=True)
      result['split'] = pd.Series(labels[:len(result)], index=result.index, dtype=object)
      return result
  class DataFunctions():
      """Bridge between COCO-converted datapoints and a YOLOv8 dataset directory.

      Holds a ``COCOAnnotationConverter`` built from ``annotation_file`` and a
      ``YOLOAnnotationConverter`` rooted at ``yolo_dir``. The YOLO converter
      reuses the class list of the COCO converter.
      """

      # Dataset YAML written by create_yolo_v8_dataset_yaml and removed by
      # remove_yolo_v8_dataset. It is a relative path, so it lives in the
      # current working directory.
      YAML_FILE = 'custom_yolo.yaml'

      def __init__(self, annotation_file, yolo_dir, to_name='image', from_name='label', label_type='bbox'):
          """Build both converters with the same to_name/from_name/label_type settings."""
          self.coco_conv = COCOAnnotationConverter(
              annotation_file=annotation_file,
              to_name=to_name,
              from_name=from_name,
              label_type=label_type
          )
          self.yolo_conv = YOLOAnnotationConverter(
              dataset_dir=yolo_dir,
              classes=self.coco_conv.classes,
              to_name=to_name,
              from_name=from_name,
              label_type=label_type)

      def remove_yolo_v8_labels(self):
          """Delete the ``labels`` directory under the YOLO dataset dir, if present."""
          labels = os.path.join(self.yolo_conv.dataset_dir, 'labels')
          shutil.rmtree(labels, ignore_errors=True)

      def remove_yolo_v8_dataset(self):
          """Delete the whole YOLO dataset directory and the dataset YAML, if present."""
          shutil.rmtree(self.yolo_conv.dataset_dir, ignore_errors=True)
          # EAFP instead of exists-then-remove: no race between the check and the delete.
          try:
              os.remove(self.YAML_FILE)
          except FileNotFoundError:
              pass

      def create_yolo_v8_dataset_yaml(self, dataset, download=True):
          """(Re)build the YOLOv8 dataset on disk and write its YAML description.

          Args:
              dataset: Queryable dataset. The code expects ``dataset['split'] == s``
                  filtering, ``.all()``, ``.download_files(...)`` and
                  ``.get_blob_fields('annotation')`` — presumably a DagsHub data-engine
                  datasource; confirm against callers.
              download: If True, wipe the dataset dir and YAML, then download images
                  for each of 'train', 'valid', 'test' into ``images/<split>``. If
                  False, keep existing images and only wipe the ``labels`` dir.
          """
          path = os.path.abspath(self.yolo_conv.dataset_dir)
          if download:
              self.remove_yolo_v8_dataset()
              for split in ('train', 'valid', 'test'):
                  split_ds = dataset[dataset['split'] == split]
                  target_dir = os.path.join(path, f'images/{split}')
                  split_ds.all().download_files(target_dir=target_dir, keep_source_prefix=False)
          else:
              self.remove_yolo_v8_labels()
          # Write YOLO labels for every datapoint's annotation.
          for dp in dataset.all().get_blob_fields('annotation'):
              self.yolo_conv.from_de(dp)
          yaml_dict = {
              'path': path,
              'train': 'images/train',
              'val': 'images/valid',
              'test': 'images/test',
              'names': dict(enumerate(self.yolo_conv.classes)),
          }
          with open(self.YAML_FILE, 'w', encoding='utf-8') as file:
              yaml.safe_dump(yaml_dict, file)

      def create_categories_COCO(self, annotations):
          """Return the distinct first labels of all annotation results, comma-separated.

          Args:
              annotations: JSON document as ``bytes`` or ``str``, read as
                  ``{'annotations': [{'result': [{'type': T, 'value': {T: [label, ...]}}]}]}``.

          Returns:
              The labels sorted and joined with ', ', or '' when there are none.
              Results with an empty label list are skipped.
          """
          # json.loads accepts both str and bytes (UTF-8/16/32), so callers may pass either.
          json_annotation = json.loads(annotations)
          categories = set()
          if 'annotations' in json_annotation:
              for annotation in json_annotation['annotations']:
                  for result in annotation['result']:
                      labels = result['value'][result['type']]
                      if labels:
                          categories.add(labels[0])
          # Sort so identical label sets always yield the same string; set order
          # for str varies between interpreter runs (hash randomisation).
          return ', '.join(sorted(str(item) for item in categories))

      def create_metadata(self, s):
          """Populate metadata fields on datapoint ``s`` in place and return it.

          Sets ``valid_datapoint`` and ``year``. When ``s`` lacks a truthy
          ``annotation``, the COCO converter supplies one. When an annotation is
          present, ``categories`` is derived from it.
          """
          s["valid_datapoint"] = True
          # NOTE(review): hard-coded year; presumably the COCO 2017 release — confirm.
          s['year'] = 2017
          # Add annotations where relevant
          if not ('annotation' in s and s['annotation']):
              annotation = self.coco_conv.to_de(s)
              s['annotation'] = annotation
          if 'annotation' in s and s['annotation']:
              s['categories'] = self.create_categories_COCO(s["annotation"])
          return s
Tip!

Press p (or the left-arrow key) to see the previous file, or n (or the right-arrow key) to see the next file

Comments

Loading...