Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

training.py 1.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import posixpath
import re
import subprocess

import dagshub
import yaml

from model_def import NumCatModel
  6. with open(r"./general_params.yml") as f:
  7. params = yaml.safe_load(f)
  8. CHUNK_SIZE = params["chunk_size"]
  9. TARGET_LABEL = params["target_col"]
  10. train_df_path = "rML-train.csv"
  11. def get_remote_gs_wfs():
  12. print("Retreiving location of remote working file system...")
  13. stream = os.popen("dvc remote list --local")
  14. output = stream.read()
  15. remote_wfs_loc = output.split("\t")[1].split("\n")[0]
  16. return remote_wfs_loc
  17. def load_and_train(remote_wfs, random_state=42):
  18. train_data_loc = os.path.join(remote_wfs, train_df_path)
  19. print("Initializing models...")
  20. model = NumCatModel(random_state=random_state)
  21. model.train(chunksize=CHUNK_SIZE, data_loc=train_data_loc, target=TARGET_LABEL,)
  22. print("Saving models locally...")
  23. with dagshub.dagshub_logger(should_log_metrics=False) as logger:
  24. model.save_model(logger=logger)
  25. if __name__ == "__main__":
  26. remote_wfs = get_remote_gs_wfs()
  27. load_and_train(remote_wfs)
  28. print("Loading and training done!")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...