random_forrest.py 1.1 KB

import os

import plac
from sklearn.ensemble import RandomForestClassifier

from src.utils import (
    dump_yaml,
    evaluate_model,
    print_results,
    read_data,
    read_yaml,
    save_results,
)


@plac.annotations(
    data_path=("Path to source data", "option", "i", str),
    n_estimators=("Number of trees in the forest", "option", "e", int),
    max_samples=("Number of samples drawn to train each tree", "option", "s", int),
    out_path=("Path to save trained model", "option", "o", str),
)
def main(
    data_path="data/features/",
    out_path="models/r_forrest/",
    n_estimators=10,
    max_samples=30,
):
    # NOTE: n_estimators and max_samples are accepted on the command line but
    # are not passed to the model; hyperparameters come from params.yaml.
    X_train, X_test, y_train, y_test = read_data(data_path)

    name = "RandomForrest"
    # Load the hyperparameters stored under the "forrest" key of params.yaml.
    params = read_yaml("params.yaml", "forrest")

    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)

    # Score on the held-out split, then print and persist the results.
    accuracy, c_matrix, fig = evaluate_model(model, X_test, y_test)
    print_results(accuracy, c_matrix, name)
    save_results(out_path, model, fig)
    dump_yaml(
        dict(accuracy=accuracy, confusion_matrix=c_matrix),
        os.path.join(out_path, "metrics.yaml"),
    )


if __name__ == "__main__":
    plac.call(main)
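
In the script above, the n_estimators and max_samples options are parsed but never reach RandomForestClassifier, which is built only from the values in params.yaml. The following is a minimal sketch of one way command-line values could override the file. It assumes the option defaults would be changed to None so that "not given" can be told apart from a real value. merge_params is a hypothetical helper, not part of src.utils, and the dictionary below stands in for whatever read_yaml returns.

```python
def merge_params(file_params, **overrides):
    """Return a copy of file_params with every non-None override applied.

    Hypothetical helper for illustration; it does not exist in src.utils.
    """
    merged = dict(file_params or {})  # copy, so the caller's dict is untouched
    merged.update({k: v for k, v in overrides.items() if v is not None})
    return merged


# Stand-in for the dict loaded from params.yaml (values are made up).
file_params = {"n_estimators": 100, "max_depth": 5}

# n_estimators was given on the command line, max_samples was not.
params = merge_params(file_params, n_estimators=10, max_samples=None)
print(params)
```

The merged dictionary could then be passed as RandomForestClassifier(**params) in place of the file-only parameters. Whether the command line or the file should take precedence is a design choice the original script does not state.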