prepare_data.py

from . import shared
from .shared import save
import pandas as pd


def split_train_test(ratio=0.2, random_seed=42):
    from sklearn.model_selection import train_test_split
    df = pd.read_csv(shared.raw_data, encoding='utf-8')
    # Label: True if the question is tagged 'machine-learning'
    df[shared.col_label] = df[shared.col_tags].fillna('').str.contains('machine-learning')
    train_df, test_df = train_test_split(df, test_size=ratio, random_state=random_seed, stratify=df[shared.col_label])
    train_df.to_csv(shared.train_data, index=False)
    test_df.to_csv(shared.test_data, index=False)


from html.parser import HTMLParser


# Strips HTML tags, keeping only the text content
class MLStripper(HTMLParser):
    def __init__(self):
        super().__init__()
        self.reset()
        self.strict = False
        self.convert_charrefs = True
        self.fed = []

    def handle_data(self, d):
        self.fed.append(d)

    def get_data(self):
        return ''.join(self.fed)


def text_preprocess(s):
    strip = MLStripper()
    strip.feed(s.lower())
    return strip.get_data()


def text_col(df):
    # Concatenate title and body into a single lowercased, HTML-free text column
    return (df[shared.col_title].fillna('') + df[shared.col_body].fillna('')).astype('U', copy=False).apply(text_preprocess)


import re
token_pattern = re.compile(r"(?u)\b\w\w+\b")
# TODO: Better number pattern
number_pattern = re.compile(r"^\d+e?|e?\d+$")


def tokenizer(s):
    """
    Turns numeric tokens into a single <num> token, to prevent lots of redundant terms for different numbers
    """
    tokens = token_pattern.findall(s)
    return ["<num>" if number_pattern.match(t) else t for t in tokens]


def build_pipeline():
    from sklearn.impute import SimpleImputer
    from sklearn.preprocessing import StandardScaler, FunctionTransformer
    from sklearn.pipeline import make_pipeline
    from sklearn.compose import ColumnTransformer
    import tutorial.prepare_data
    from tutorial import prepare_data  # Required for proper pickling of this pipeline
    return ColumnTransformer([
        ('passthrough', 'passthrough', [shared.col_id, shared.col_label]),
        ('num_cols', make_pipeline(SimpleImputer(), StandardScaler()), shared.extra_feature_cols),
    ])


def map_dataframe(df, pipeline):
    cols = [shared.col_id, shared.col_label] + shared.extra_feature_cols
    return pd.DataFrame(pipeline.transform(df).astype(float), columns=cols)


def prepare_data():
    pipeline = build_pipeline()

    print("Loading train data")
    train_df = pd.read_csv(shared.train_data, encoding='utf-8')
    print("Loading test data")
    test_df = pd.read_csv(shared.test_data, encoding='utf-8')
    print("Done")

    print("Fitting the pipeline...")
    pipeline.fit(train_df)
    print("Done")

    print("Transforming data")
    train_df = map_dataframe(train_df, pipeline)
    test_df = map_dataframe(test_df, pipeline)
    print("Done")

    save(pipeline, shared.pipeline_pkl)

    print("Saving training data")
    train_df.to_pickle(shared.train_processed)
    print("Saving test data")
    test_df.to_pickle(shared.test_processed)
    print("Done")

    # Log the pipeline's scalar hyperparameters to pipeline-params.yml
    from dagshub import dagshub_logger
    with dagshub_logger(should_log_metrics=False, hparams_path='pipeline-params.yml') as logger:
        params = {k: v for k, v in pipeline.get_params().items() if v is None or type(v) in [str, int, float, bool]}
        print('Logging pipeline params:')
        print(params)
        logger.log_hyperparams(params)


def main():
    print("Splitting train and test...")
    split_train_test()
    print("Done")
    print("Preparing data...")
    prepare_data()
    print("Done")


if __name__ == "__main__":
    main()
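As a quick sanity check (not part of prepare_data.py itself), here is a minimal sketch of what the text helpers above produce. It assumes the module is importable as tutorial.prepare_data, the package layout implied by the import inside build_pipeline():

    from tutorial.prepare_data import text_preprocess, tokenizer

    raw = "<p>How do I train 10 models on 2e5 rows?</p>"
    clean = text_preprocess(raw)   # lowercases and strips the HTML tags
    print(clean)                   # how do i train 10 models on 2e5 rows?
    print(tokenizer(clean))        # ['how', 'do', 'train', '<num>', 'models', 'on', '<num>', 'rows']

Folding every number into a single <num> token keeps the vocabulary compact when these texts are later vectorized. The script itself is meant to be run end to end, e.g. with python -m tutorial.prepare_data (again assuming that package name): main() splits the raw CSV into train/test files and then fits, applies, and saves the preprocessing pipeline.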