Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

split_train_dev.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) 2021. Jeffrey J. Nirschl. All rights reserved.
  3. #
  4. # Licensed under the MIT license. See the LICENSE.md file in the project
  5. # root directory for full license information.
  6. #
  7. # Time-stamp: <>
  8. # ======================================================================
  9. import argparse
  10. import os
  11. from pathlib import Path
  12. import pandas as pd
  13. from sklearn.model_selection import StratifiedKFold
  14. from src.data import load_data, load_params
  15. def main(mapfile, output_dir=None):
  16. """Split data into train and dev sets"""
  17. if type(mapfile) is str:
  18. assert (os.path.isfile(mapfile)), FileNotFoundError
  19. # read file
  20. train_df = load_data(mapfile,
  21. sep=",", header=0,
  22. index_col=0)
  23. else:
  24. train_df = mapfile
  25. # set index
  26. train_df.index.name = "index"
  27. # load params
  28. params = load_params()
  29. params_split = params['train_test_split']
  30. params_split["random_seed"] = params["random_seed"]
  31. # get filenames and dependent variables (labels)
  32. train_labels = train_df.pop(params_split["target_class"])
  33. train_files = train_df
  34. # K-fold split into train and dev sets stratified by train_labels
  35. # using random seed for reproducibility
  36. skf = StratifiedKFold(n_splits=params_split['n_split'],
  37. random_state=params_split['random_seed'],
  38. shuffle=params_split['shuffle'])
  39. # create splits
  40. split_df = pd.DataFrame()
  41. for n_fold, (train_idx, test_idx) in enumerate(skf.split(train_files,
  42. train_labels)):
  43. fold_name = f"fold_{n_fold + 1:02d}"
  44. # create intermediate dataframe for each fold
  45. temp_df = pd.DataFrame({"image_id": train_idx,
  46. fold_name: "train"}).set_index("image_id")
  47. temp_df = temp_df.append(pd.DataFrame({"image_id": test_idx,
  48. fold_name: "test"}).set_index("image_id"))
  49. # append first fold to empty dataframe or join cols if n_fold > 0
  50. split_df = split_df.append(temp_df) if n_fold == 0 else split_df.join(temp_df)
  51. # sort by index
  52. split_df = split_df.sort_index()
  53. if output_dir:
  54. assert (os.path.isdir(output_dir)), NotADirectoryError
  55. output_dir = Path(output_dir).resolve()
  56. # save output dataframe with indices for train and dev sets
  57. split_df.to_csv(output_dir.joinpath("split_train_dev.csv"),
  58. na_rep="nan")
  59. else:
  60. return split_df
  61. if __name__ == '__main__':
  62. parser = argparse.ArgumentParser()
  63. parser.add_argument("-tr", "--train", dest="train_path",
  64. required=True, help="Train CSV file")
  65. parser.add_argument("-o", "--out-dir", dest="output_dir",
  66. default=Path("./data/processed ").resolve(),
  67. required=False, help="output directory")
  68. args = parser.parse_args()
  69. # split data into train and dev sets
  70. main(args.train_path, args.output_dir)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...