
train.py 750 B

```python
import os

import src.reddit_utils as r_utils
from src.model_def import NumCatModel
from src.utilities import read_yaml

# Read the "pre_process" and "train" sections of the project config.
pre_process = read_yaml("params.yaml", "pre_process")
train = read_yaml("params.yaml", "train")

CHUNK_SIZE = pre_process["chunk_size"]
TARGET_LABEL = pre_process["target_col"]


def load_and_train(random_state=42):
    # Location of the training data written by the pre-processing stage.
    train_data_loc = os.path.join("data/processed", r_utils.TRAIN_DF_PATH)

    print("Initializing models...")
    model = NumCatModel(train["loss"], random_state=random_state)
    # Pass the configured chunk size, data location and target column to the model.
    model.train(chunksize=CHUNK_SIZE, data_loc=train_data_loc, target=TARGET_LABEL)

    print("Saving models locally...")
    model.save_model()


if __name__ == "__main__":
    load_and_train()
    print("Loading and training done!")
```
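The internals of `NumCatModel` are not part of this file, so what `train(chunksize=..., data_loc=..., target=...)` actually does is not shown. As a rough, hypothetical illustration of chunked training with that call shape, the sketch below streams a CSV in fixed-size chunks using only the standard library and fits a trivial majority-class baseline. `MajorityClassModel`, `iter_chunks`, the CSV layout and the JSON save format are all assumptions made for this example, not the project's code.

```python
import csv
import itertools
import json
import os
import tempfile
from collections import Counter
from typing import Iterable, Iterator, List


def iter_chunks(rows: Iterable, chunksize: int) -> Iterator[List]:
    """Yield successive lists of at most `chunksize` items from `rows`."""
    it = iter(rows)
    while True:
        chunk = list(itertools.islice(it, chunksize))
        if not chunk:
            return
        yield chunk


class MajorityClassModel:
    """Hypothetical stand-in for NumCatModel: predicts the most frequent target label."""

    def __init__(self) -> None:
        self.label_counts: Counter = Counter()
        self.n_chunks = 0

    def train(self, chunksize: int, data_loc: str, target: str) -> None:
        # Stream the CSV and update the label counts one chunk at a time,
        # so the whole file never has to be held in memory at once.
        with open(data_loc, newline="") as f:
            for chunk in iter_chunks(csv.DictReader(f), chunksize):
                self.label_counts.update(row[target] for row in chunk)
                self.n_chunks += 1

    def predict(self) -> str:
        # Most frequent label seen during training.
        return self.label_counts.most_common(1)[0][0]

    def save_model(self, path: str) -> None:
        # Persist the learned counts as JSON (an assumed format, for illustration only).
        with open(path, "w") as f:
            json.dump({"label_counts": dict(self.label_counts)}, f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        data_loc = os.path.join(tmp, "train.csv")
        with open(data_loc, "w", newline="") as f:
            f.write("text_len,label\n12,a\n40,b\n7,a\n19,a\n33,b\n")
        model = MajorityClassModel()
        model.train(chunksize=2, data_loc=data_loc, target="label")
        print(model.n_chunks, model.predict())  # 3 a
```

Reading in chunks keeps memory bounded by `chunksize` rows rather than by file size, which is presumably why the real script takes `chunk_size` from `params.yaml`; a real model would update its parameters per chunk instead of just counting labels.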