training.py

import os
import re

import joblib
import numpy as np
import pandas as pd
import yaml
import dagshub
from scipy.sparse import dia_matrix
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier

import reddit_utils
from reddit_utils import calculate_metrics, prepare_log

with open(r"./general_params.yml") as f:
    params = yaml.safe_load(f)

with open(r"./training_params.yml") as f:
    training_params = yaml.safe_load(f)

CHUNK_SIZE = params["chunk_size"]
TARGET_LABEL = params["target_col"]

MODEL_TYPE_TEXT = "model_text"
MODEL_TYPE_NUM_CAT = "model_num_cat"
MODEL_TYPE_OTHER = ""

MODEL_TYPE = (
    MODEL_TYPE_TEXT
    if training_params["use_text_cols"]
    else MODEL_TYPE_NUM_CAT
    if training_params["use_number_category_cols"]
    else MODEL_TYPE_OTHER
)

train_df_path = "rML-train.csv"


# ----- Helper Functions -----
# A partial fit for the TfidfVectorizer, courtesy of @maxymoo on Stack Overflow:
# https://stackoverflow.com/questions/39109743/adding-new-text-to-sklearn-tfidif-vectorizer-python/39114555#39114555
def partial_fit(self, X):
    # On the first call, fall back to a regular fit
    if not hasattr(self, "is_initialized"):
        self.fit(X)
        self.n_docs = len(X)
        self.is_initialized = True
    else:
        max_idx = max(self.vocabulary_.values())
        for a in X:
            # update vocabulary_ with any unseen tokens
            if self.lowercase:
                a = str(a).lower()
            tokens = re.findall(self.token_pattern, a)
            for w in tokens:
                if w not in self.vocabulary_:
                    max_idx += 1
                    self.vocabulary_[w] = max_idx

            # update idf_: recover the document frequencies from the stored
            # idf values, count the new document, then recompute smoothed idf
            df = (self.n_docs + self.smooth_idf) / np.exp(
                self.idf_ - 1
            ) - self.smooth_idf
            self.n_docs += 1
            df.resize(len(self.vocabulary_))
            for w in tokens:
                df[self.vocabulary_[w]] += 1
            idf = np.log((self.n_docs + self.smooth_idf) / (df + self.smooth_idf)) + 1
            self._tfidf._idf_diag = dia_matrix((idf, 0), shape=(len(idf), len(idf)))
# ----- End Helper Functions -----


class TextModel:
    def __init__(self, random_state=42):
        # note: scikit-learn >= 1.3 renames this loss to "log_loss"
        self.model = SGDClassifier(loss="log", random_state=random_state)
        print("Generate TFIDF features...")
        TfidfVectorizer.partial_fit = partial_fit
        self.tfidf = TfidfVectorizer(max_features=25000)

    def train(self):
        print("Training TextModel...")
        # First pass: fit the TF-IDF vocabulary incrementally, chunk by chunk.
        # `remote_wfs` is the module-level global set in __main__ below.
        for i, chunk in enumerate(
            pd.read_csv(os.path.join(remote_wfs, train_df_path), chunksize=CHUNK_SIZE)
        ):
            print(f"Fitting TFIDF to chunk {i+1}...")
            self.tfidf.partial_fit(chunk["title_and_body"].values.astype("U"))
        print("TFIDF feature matrix created!")

        # Second pass: transform each chunk and train the classifier incrementally
        for i, chunk in enumerate(
            pd.read_csv(os.path.join(remote_wfs, train_df_path), chunksize=CHUNK_SIZE)
        ):
            print(f"Training on chunk {i+1}...")
            df_y = chunk[TARGET_LABEL]
            tfidf_X = self.tfidf.transform(chunk["title_and_body"].values.astype("U"))
            self.model.partial_fit(tfidf_X, df_y, classes=np.array([0, 1]))

    def save_model(self, logger=None):
        os.makedirs(reddit_utils.MODELS_DIR, exist_ok=True)
        joblib.dump(self.model, reddit_utils.MODEL_PATH)
        joblib.dump(self.tfidf, reddit_utils.TFIDF_PATH)

        # log params
        if logger:
            logger.log_hyperparams(prepare_log(self.tfidf.get_params(), "tfidf"))
            logger.log_hyperparams(prepare_log(self.model.get_params(), "model"))
            logger.log_hyperparams(model_class=type(self.model).__name__)


def get_remote_gs_wfs():
    print("Retrieving location of remote working file system...")
    # `dvc remote list --local` prints lines of the form "<name>\t<url>";
    # grab the URL of the first remote
    stream = os.popen("dvc remote list --local")
    output = stream.read()
    remote_wfs_loc = output.split("\t")[1].split("\n")[0]
    return remote_wfs_loc


def load_and_train(remote_wfs, model_type=None, random_state=42):
    print("Initializing models...")
    if model_type == MODEL_TYPE_TEXT:
        model = TextModel(random_state=random_state)
    else:
        # TODO
        return

    model.train()
    print("Saving models locally...")
    with dagshub.dagshub_logger(should_log_metrics=False) as logger:
        logger.log_hyperparams(feature_type="text")
        model.save_model(logger=logger)


if __name__ == "__main__":
    remote_wfs = get_remote_gs_wfs()
    load_and_train(remote_wfs, MODEL_TYPE)
    print("Loading and training done!")
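
For orientation, the script assumes two YAML config files sit next to it. A minimal sketch of what they might contain, with the keys taken from the code above and placeholder values (the project's actual settings may differ):

import yaml

# Hypothetical contents for the two config files training.py reads.
# Only the keys come from the code above; the values are illustrative.
GENERAL_PARAMS = """
chunk_size: 50000   # rows per pandas chunk when streaming rML-train.csv
target_col: label   # name of the binary target column
"""

TRAINING_PARAMS = """
use_text_cols: true              # selects the TextModel branch
use_number_category_cols: false  # the num/cat branch is still a TODO
"""

params = yaml.safe_load(GENERAL_PARAMS)
training_params = yaml.safe_load(TRAINING_PARAMS)
assert params["chunk_size"] == 50000
assert training_params["use_text_cols"] is True

With `use_text_cols` true, `MODEL_TYPE` resolves to `MODEL_TYPE_TEXT` and `load_and_train` builds a `TextModel`; any other combination currently returns early.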
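The monkey-patched `partial_fit` is the one non-obvious trick here: it recovers document frequencies from the stored `idf_` values, folds in each new document, and rewrites the transformer's idf diagonal, so TF-IDF can be fit on data too large for memory. A quick self-contained check of that behavior (assuming the `partial_fit` function above, with its imports, is in scope):

from sklearn.feature_extraction.text import TfidfVectorizer

# Same monkey-patch TextModel applies in __init__, using partial_fit from above
TfidfVectorizer.partial_fit = partial_fit

vec = TfidfVectorizer()
vec.partial_fit(["the cat sat on the mat"])      # first call: a regular fit
vocab_before = len(vec.vocabulary_)
vec.partial_fit(["an entirely fresh document"])  # second call: grows vocab and idf
assert len(vec.vocabulary_) > vocab_before
assert vec.transform(["fresh document"]).nnz > 0  # new tokens now produce features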