Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model_def.py 1.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
  1. import joblib
  2. import numpy as np
  3. import os
  4. import pandas as pd
  5. from sklearn.linear_model import SGDClassifier
  6. import reddit_utils
  7. class NumCatModel:
  8. def __init__(self, random_state=42):
  9. self.model = SGDClassifier(loss="log", random_state=random_state)
  10. def train(self, chunksize, data_loc, target):
  11. print("Training NumCatModel...")
  12. for i, chunk in enumerate(
  13. pd.read_csv(data_loc, chunksize=chunksize)
  14. ):
  15. print(f"Training on chunk {i+1}...")
  16. df_y = chunk[target]
  17. cols_to_train = reddit_utils.NUM_COL_NAMES + reddit_utils.CAT_COL_NAMES
  18. df_X = chunk[cols_to_train]
  19. self.model.partial_fit(df_X, df_y, classes=np.array([0, 1]))
  20. def save_model(self, logger=None):
  21. os.makedirs(reddit_utils.MODELS_DIR, exist_ok=True)
  22. joblib.dump(self.model, reddit_utils.MODEL_PATH)
  23. # log params
  24. if logger:
  25. logger.log_hyperparams(feature_type="numerical + categorical")
  26. logger.log_hyperparams(model_class=type(self.model).__name__)
  27. logger.log_hyperparams(reddit_utils.prepare_log(self.model.get_params(), "model"))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...