prepare_data.py (2.9 KB)
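Data-preparation step of the pipeline: it labels the raw data, splits it into stratified train/test CSVs, builds TF-IDF features for both sets, saves the fitted vectorizer and matrices, and logs the vectorizer hyperparameters with the DagsHub logger.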

from . import shared
from .shared import save


def split_train_test(ratio=0.2, random_seed=42):
    """Label the raw data and split it into stratified train/test CSVs."""
    import pandas as pd
    from sklearn.model_selection import train_test_split

    df = pd.read_csv(shared.raw_data, encoding='utf-8')
    # A row is positive if its tags contain 'machine-learning'
    df[shared.col_label] = df[shared.col_tags].fillna('').str.contains('machine-learning')
    # Kept for reference; the stratified split below uses the full DataFrame
    df_positive = df[df[shared.col_label]]
    df_negative = df[df[shared.col_label] != True]

    train_df, test_df = train_test_split(df, test_size=ratio, random_state=random_seed,
                                         stratify=df[shared.col_label])
    train_df.to_csv(shared.train_data, index=False)
    test_df.to_csv(shared.test_data, index=False)


from html.parser import HTMLParser


class MLStripper(HTMLParser):
    """HTML parser that keeps only the text content of the markup it is fed."""

    def __init__(self):
        super().__init__()
        self.reset()
        self.strict = False
        self.convert_charrefs = True
        self.fed = []

    def handle_data(self, d):
        self.fed.append(d)

    def get_data(self):
        return ''.join(self.fed)


def text_preprocess(s):
    # Lowercase the text and strip HTML tags before vectorization
    strip = MLStripper()
    strip.feed(s.lower())
    return strip.get_data()


def vectorize_text():
    """Build TF-IDF matrices for the train and test sets."""
    import pandas as pd
    from sklearn.feature_extraction.text import TfidfVectorizer

    train_df = pd.read_csv(shared.train_data, encoding='utf-8')
    test_df = pd.read_csv(shared.test_data, encoding='utf-8')

    def text_col(df):
        # Concatenate title and body into a single unicode text column
        df[shared.col_text] = (df[shared.col_title].fillna('') + df[shared.col_body].fillna('')).astype('U', copy=False)

    text_col(train_df)
    text_col(test_df)

    import tutorial.prepare_data  # Required for proper pickling to work
    vectorizer = TfidfVectorizer(encoding='utf-8', preprocessor=tutorial.prepare_data.text_preprocess,
                                 stop_words='english', analyzer='word',
                                 max_features=50000, ngram_range=(1, 3))
    vectorizer.fit(train_df[shared.col_text])
    train_tfidf_mat = vectorizer.transform(train_df[shared.col_text])
    test_tfidf_mat = vectorizer.transform(test_df[shared.col_text])
    return train_df, test_df, vectorizer, train_tfidf_mat, test_tfidf_mat


def prepare_data():
    """Save the fitted vectorizer, the processed text CSVs, and the TF-IDF matrices."""
    train_df, test_df, vectorizer, train_tfidf_mat, test_tfidf_mat = vectorize_text()
    save(vectorizer, shared.vectorizer_pkl)
    train_df[[shared.col_id, shared.col_text]].to_csv(shared.train_processed, index=False)
    test_df[[shared.col_id, shared.col_text]].to_csv(shared.test_processed, index=False)
    save(train_tfidf_mat, shared.train_tfidf)
    save(test_tfidf_mat, shared.test_tfidf)

    # Log the vectorizer hyperparameters (scalar values only) to a params file
    from dagshub import dagshub_logger
    with dagshub_logger(should_log_metrics=False, hparams_path='vectorizer-params.yml') as logger:
        params = {k: v for k, v in vectorizer.get_params().items()
                  if v is None or type(v) in [str, int, float, bool]}
        logger.log_hyperparams(params)


def main():
    split_train_test()
    prepare_data()


if __name__ == "__main__":
    main()
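The script imports its configuration from a sibling shared module (file paths such as shared.raw_data and shared.train_tfidf, column names such as shared.col_text, and the save helper); that module's contents are not shown here. Assuming the package layout implied by the tutorial.prepare_data import, the whole step can be run as a module from the project root, for example:

    python -m tutorial.prepare_data

This calls main(), which first writes the stratified train/test CSVs and then saves the fitted TfidfVectorizer, the processed text CSVs, the TF-IDF matrices, and vectorizer-params.yml.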