
training.py 865 B

```python
import os

import yaml
import dagshub

from model_def import NumCatModel
import reddit_utils

# Load the shared project parameters from the YAML config file.
with open(r"./general_params.yml") as f:
    params = yaml.safe_load(f)

CHUNK_SIZE = params["chunk_size"]    # chunk size passed to NumCatModel.train
TARGET_LABEL = params["target_col"]  # name of the target (label) column


def load_and_train(random_state=42):
    # Record training metrics in training_metrics.csv via the DagsHub logger.
    with dagshub.dagshub_logger(metrics_path="training_metrics.csv") as logger:
        train_data_loc = os.path.join('processed', reddit_utils.TRAIN_DF_PATH)

        print("Initializing models...")
        model = NumCatModel(random_state=random_state)
        # Train on the processed training data, forwarding the chunk size,
        # target column, and logger to the model.
        model.train(
            chunksize=CHUNK_SIZE,
            data_loc=train_data_loc,
            target=TARGET_LABEL,
            logger=logger,
        )

        print("Saving models locally...")
        model.save_model(logger=logger)


if __name__ == "__main__":
    load_and_train()
    print("Loading and training done!")
```
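The script hands the actual training to `NumCatModel.train`, whose implementation in `model_def` is not shown on this page. As a rough illustration of what training with a `chunksize` can look like, here is a minimal, stdlib-only sketch under stated assumptions: the training data is a CSV file, and the model can be updated one chunk at a time. `iter_chunks`, `RunningMeanModel`, `partial_fit`, and `train_in_chunks` are hypothetical names, not part of this repository or of the DagsHub client, and the "model" only tracks the mean of the target column.

```python
import csv
import io
from itertools import islice
from typing import Dict, Iterable, Iterator, List


def iter_chunks(rows: Iterable[Dict[str, str]], chunk_size: int) -> Iterator[List[Dict[str, str]]]:
    """Yield successive lists of at most `chunk_size` rows."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    rows = iter(rows)  # ensure islice consumes rows instead of restarting on a list
    while True:
        chunk = list(islice(rows, chunk_size))
        if not chunk:
            return
        yield chunk


class RunningMeanModel:
    """Hypothetical stand-in for NumCatModel: predicts the mean of the target
    column, updated incrementally so only one chunk is held in memory."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def partial_fit(self, chunk: List[Dict[str, str]], target: str) -> None:
        for row in chunk:
            self.total += float(row[target])
            self.count += 1

    def predict(self) -> float:
        return self.total / self.count if self.count else 0.0


def train_in_chunks(csv_file, target: str, chunk_size: int) -> RunningMeanModel:
    """Hypothetical chunked training loop over a CSV file-like object."""
    model = RunningMeanModel()
    reader = csv.DictReader(csv_file)
    for i, chunk in enumerate(iter_chunks(reader, chunk_size)):
        model.partial_fit(chunk, target)
        print(f"chunk {i}: {len(chunk)} rows, running mean={model.predict():.3f}")
    return model


if __name__ == "__main__":
    data = io.StringIO("score,label\n1,0\n2,1\n3,1\n4,0\n5,1\n")
    m = train_in_chunks(data, target="label", chunk_size=2)
    print(m.count, m.predict())  # -> 5 0.6
```

Because only one chunk is in memory at a time, the same loop works for training files larger than RAM; a real model would replace the running mean with its own incremental fitting step.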