Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

split_train_dev.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) 2021. Jeffrey J. Nirschl. All rights reserved.
  3. #
  4. # Licensed under the MIT license. See the LICENSE.md file in the project
  5. # root directory for full license information.
  6. #
  7. # Time-stamp: <>
  8. # ======================================================================
  9. import argparse
  10. import os
  11. from pathlib import Path
  12. import pandas as pd
  13. from sklearn.model_selection import StratifiedKFold
  14. from src.data import load_data, load_params
  15. def main(train_path, output_dir):
  16. """Split data into train and dev sets"""
  17. assert (os.path.isdir(output_dir)), NotADirectoryError
  18. output_dir = Path(output_dir).resolve()
  19. # read file
  20. train_df = load_data(train_path,
  21. sep=",", header=0,
  22. index_col="PassengerId")
  23. # load params
  24. params = load_params()
  25. params_split = params['train_test_split']
  26. params_split["random_seed"] = params["random_seed"]
  27. # get independent variables (features) and
  28. # dependent variables (labels)
  29. train_feats = train_df.drop(params_split["target_class"], axis=1)
  30. train_labels = train_df[params_split["target_class"]]
  31. # K-fold split into train and dev sets stratified by train_labels
  32. # using random seed for reproducibility
  33. skf = StratifiedKFold(n_splits=params_split['n_split'],
  34. random_state=params_split['random_seed'],
  35. shuffle=params_split['shuffle'])
  36. # create splits
  37. split_df = pd.DataFrame()
  38. for n_fold, (train_idx, test_idx) in enumerate(skf.split(train_feats,
  39. train_labels)):
  40. fold_name = f"fold_{n_fold + 1:02d}"
  41. # create intermediate dataframe for each fold
  42. temp_df = pd.DataFrame({"PassengerId": train_idx,
  43. fold_name: "train"}).set_index("PassengerId")
  44. temp_df = temp_df.append(pd.DataFrame({"PassengerId": test_idx,
  45. fold_name: "test"}).set_index("PassengerId"))
  46. # append first fold to empty dataframe or join cols if n_fold > 0
  47. split_df = split_df.append(temp_df) if n_fold == 0 else split_df.join(temp_df)
  48. # sort by index
  49. split_df = split_df.sort_index()
  50. # save output dataframe with indices for train and dev sets
  51. split_df.to_csv(output_dir.joinpath("split_train_dev.csv"),
  52. na_rep="nan")
  53. if __name__ == '__main__':
  54. parser = argparse.ArgumentParser()
  55. parser.add_argument("-tr", "--train", dest="train_path",
  56. required=True, help="Train CSV file")
  57. parser.add_argument("-o", "--out-dir", dest="output_dir",
  58. default=Path("./data/processed ").resolve(),
  59. required=False, help="output directory")
  60. args = parser.parse_args()
  61. # split data into train and dev sets
  62. main(args.train_path, args.output_dir)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...