Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model_def.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
  1. import os
  2. from pathlib import Path
  3. import joblib
  4. import numpy as np
  5. import pandas as pd
  6. from sklearn.linear_model import SGDClassifier
  7. import src.reddit_utils as r_utils
  8. from src.utilities import dump_yaml
  9. class NumCatModel:
  10. def __init__(self, loss, random_state=42):
  11. self.model = SGDClassifier(loss=loss, random_state=random_state)
  12. def train(self, chunksize, data_loc, target):
  13. print("Training NumCatModel...")
  14. for i, chunk in enumerate(pd.read_csv(data_loc, chunksize=chunksize)):
  15. print(f"Training on chunk {i+1}...")
  16. df_y = chunk[target]
  17. cols_to_train = r_utils.NUM_COL_NAMES + r_utils.CAT_COL_NAMES
  18. df_X = chunk[cols_to_train]
  19. self.model.partial_fit(df_X, df_y, classes=np.array([0, 1]))
  20. y_proba = np.array([])
  21. y_pred = np.array([])
  22. y = np.array([])
  23. print("Calculating training metrics...")
  24. for i, chunk in enumerate(pd.read_csv(data_loc, chunksize=chunksize)):
  25. df_y = chunk[target]
  26. cols_to_train = r_utils.NUM_COL_NAMES + r_utils.CAT_COL_NAMES
  27. df_X = chunk[cols_to_train]
  28. y_proba = np.concatenate((y_pred, self.model.predict_proba(df_X)[:, 1]))
  29. y_pred = np.concatenate((y_pred, self.model.predict(df_X)))
  30. y = np.concatenate((y, chunk[target]))
  31. metrics = r_utils.calculate_metrics(y_pred, y_proba, y)
  32. metrics_path = Path("models/metrics/")
  33. metrics_path.mkdir(parents=True, exist_ok=True)
  34. dump_yaml(metrics, metrics_path / "train.yaml")
  35. def save_model(self):
  36. os.makedirs(r_utils.MODELS_DIR, exist_ok=True)
  37. joblib.dump(self.model, r_utils.MODEL_PATH)
  38. # log params
  39. hyper_params = dict(
  40. feature_type="numerical + categorical",
  41. model_class=type(self.model).__name__,
  42. model=self.model.get_params(),
  43. )
  44. dump_yaml(hyper_params, Path("models/result.yaml"))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...