Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model_def.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
  1. import joblib
  2. import numpy as np
  3. import os
  4. import pandas as pd
  5. from sklearn.linear_model import SGDClassifier
  6. import reddit_utils
  7. class NumCatModel:
  8. def __init__(self, random_state=42):
  9. self.model = SGDClassifier(loss="log", random_state=random_state)
  10. def train(self, chunksize, data_loc, target, logger=None):
  11. print("Training NumCatModel...")
  12. for i, chunk in enumerate(pd.read_csv(data_loc, chunksize=chunksize)):
  13. print(f"Training on chunk {i+1}...")
  14. df_y = chunk[target]
  15. cols_to_train = reddit_utils.NUM_COL_NAMES + reddit_utils.CAT_COL_NAMES
  16. df_X = chunk[cols_to_train]
  17. self.model.partial_fit(df_X, df_y, classes=np.array([0, 1]))
  18. if logger != None:
  19. y_proba = np.array([])
  20. y_pred = np.array([])
  21. y = np.array([])
  22. print(f"Calculating training metrics...")
  23. for i, chunk in enumerate(pd.read_csv(data_loc, chunksize=chunksize)):
  24. df_y = chunk[target]
  25. cols_to_train = reddit_utils.NUM_COL_NAMES + reddit_utils.CAT_COL_NAMES
  26. df_X = chunk[cols_to_train]
  27. y_proba = np.concatenate((y_pred, self.model.predict_proba(df_X)[:, 1]))
  28. y_pred = np.concatenate((y_pred, self.model.predict(df_X)))
  29. y = np.concatenate((y, chunk[target]))
  30. metrics = reddit_utils.calculate_metrics(y_pred, y_proba, y)
  31. logger.log_metrics(reddit_utils.prepare_log(metrics, "train"))
  32. def save_model(self, logger=None):
  33. os.makedirs(reddit_utils.MODELS_DIR, exist_ok=True)
  34. joblib.dump(self.model, reddit_utils.MODEL_PATH)
  35. # log params
  36. if logger:
  37. logger.log_hyperparams(feature_type="numerical + categorical")
  38. logger.log_hyperparams(model_class=type(self.model).__name__)
  39. logger.log_hyperparams(
  40. reddit_utils.prepare_log(self.model.get_params(), "model")
  41. )
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...