prepare_data.py

import pandas as pd

from . import shared
from .shared import save
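
# The `shared` module is not shown on this page; from the way it is used
# below, it is assumed to hold the dataset paths, the relevant column names,
# and a `save` helper. A minimal hypothetical sketch (attribute names are the
# ones referenced in this file, values are illustrative only):
#
#     raw_data = 'data/raw.csv'                   # hypothetical paths
#     train_data = 'data/train.csv'
#     test_data = 'data/test.csv'
#     pipeline_pkl = 'outputs/pipeline.pkl'
#     train_processed = 'outputs/train.pkl'
#     test_processed = 'outputs/test.pkl'
#     col_id, col_label = 'Id', 'MachineLearning'  # illustrative column names
#     col_tags, col_title, col_body = 'Tags', 'Title', 'Body'
#     extra_feature_cols = ['Score', 'ViewCount']  # hypothetical numeric features
#     text_cols = [col_title, col_body]
#
#     def save(obj, path):
#         import joblib
#         joblib.dump(obj, path)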

def split_train_test(ratio=0.2, random_seed=42):
    from sklearn.model_selection import train_test_split

    df = pd.read_csv(shared.raw_data, encoding='utf-8')
    # A question is labeled positive if its tags include 'machine-learning'
    df[shared.col_label] = df[shared.col_tags].fillna('').str.contains('machine-learning')
    df_positive = df[df[shared.col_label]]   # not used below
    df_negative = df[~df[shared.col_label]]  # not used below
    # Stratify on the label so train and test keep the same class balance
    train_df, test_df = train_test_split(df, test_size=ratio, random_state=random_seed,
                                         stratify=df[shared.col_label])
    train_df.to_csv(shared.train_data, index=False)
    test_df.to_csv(shared.test_data, index=False)

from html.parser import HTMLParser


class MLStripper(HTMLParser):
    """Collects only the text content of an HTML document, dropping the tags."""

    def __init__(self):
        super().__init__()
        self.reset()
        self.strict = False
        self.convert_charrefs = True
        self.fed = []

    def handle_data(self, d):
        self.fed.append(d)

    def get_data(self):
        return ''.join(self.fed)


def text_preprocess(s):
    # Lowercase the text, then strip any HTML markup from it
    strip = MLStripper()
    strip.feed(s.lower())
    return strip.get_data()
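
# Illustrative example of what text_preprocess produces:
#
#     >>> text_preprocess('<p>Does <b>SVM</b> scale?</p>')
#     'does svm scale?'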

def text_col(df):
    # Concatenate title and body into a single unicode text column per row
    return ((df[shared.col_title].fillna('') + df[shared.col_body].fillna(''))
            .astype('U', copy=False)
            .apply(text_preprocess))

def build_pipeline():
    from sklearn.compose import ColumnTransformer
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.impute import SimpleImputer
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler, FunctionTransformer

    tfidf = TfidfVectorizer(encoding='utf-8', stop_words='english', analyzer='word',
                            max_features=25000, ngram_range=(1, 2))
    # Required for proper pickling of this pipeline: FunctionTransformer stores
    # text_col by reference, so the function must be importable as
    # tutorial.prepare_data.text_col when the pipeline is unpickled.
    from tutorial import prepare_data
    return ColumnTransformer([
        ('passthrough', 'passthrough', [shared.col_id, shared.col_label]),
        ('num_cols', make_pipeline(SimpleImputer(), StandardScaler()), shared.extra_feature_cols),
        ('tfidf', make_pipeline(FunctionTransformer(prepare_data.text_col), tfidf), shared.text_cols),
    ])
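
# A ColumnTransformer concatenates its branches in the order they are listed,
# so the transformed matrix is laid out as [col_id, col_label], then the
# scaled extra_feature_cols, then one column per TF-IDF term. map_dataframe()
# below relies on exactly this ordering when it names the output columns.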

def map_dataframe(df, pipeline):
    # Name the TF-IDF output columns after their vocabulary terms.
    # Note: get_feature_names() was deprecated in scikit-learn 1.0; on newer
    # versions use get_feature_names_out() instead.
    tfidf_cols = [f'Text_{col}' for col in pipeline.named_transformers_.tfidf[1].get_feature_names()]
    cols = [shared.col_id, shared.col_label] + shared.extra_feature_cols + tfidf_cols
    # Keep the result sparse; a dense frame with 25,000 TF-IDF columns would be huge
    return pd.DataFrame.sparse.from_spmatrix(pipeline.transform(df), columns=cols)

def prepare_data():
    pipeline = build_pipeline()
    print("Loading train data")
    train_df = pd.read_csv(shared.train_data, encoding='utf-8')
    print("Loading test data")
    test_df = pd.read_csv(shared.test_data, encoding='utf-8')
    print("Done")

    print("Fitting the pipeline...")
    pipeline.fit(train_df)
    print("Done")

    print("Transforming data")
    train_df = map_dataframe(train_df, pipeline)
    test_df = map_dataframe(test_df, pipeline)
    print("Done")

    save(pipeline, shared.pipeline_pkl)
    print("Saving training data")
    train_df.to_pickle(shared.train_processed)
    print("Saving test data")
    test_df.to_pickle(shared.test_processed)
    print("Done")
    # save(train_df, shared.train_processed)
    # save(test_df, shared.test_processed)

    # Log the pipeline's primitive hyperparameters so they can be tracked on DagsHub
    from dagshub import dagshub_logger
    with dagshub_logger(should_log_metrics=False, hparams_path='pipeline-params.yml') as logger:
        params = {k: v for k, v in pipeline.get_params().items()
                  if v is None or type(v) in [str, int, float, bool]}
        print('Logging pipeline params:')
        print(params)
        logger.log_hyperparams(params)

def main():
    print("Splitting train and test...")
    split_train_test()
    print("Done")
    print("Preparing data...")
    prepare_data()
    print("Done")


if __name__ == "__main__":
    main()
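
# Usage sketch: the relative imports and the `from tutorial import prepare_data`
# line suggest this file lives in a `tutorial` package, so the assumed
# invocation from the repository root would be:
#
#     python -m tutorial.prepare_data
#
# This splits the raw CSV, fits the preprocessing pipeline, writes the
# transformed train/test dataframes, and logs the pipeline hyperparameters.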