Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ensemble.py 1.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
  1. import plac
  2. from joblib import load
  3. from sklearn.ensemble import VotingClassifier
  4. from src.utils import (
  5. dump_yaml,
  6. evaluate_model,
  7. print_results,
  8. read_data,
  9. read_yaml,
  10. save_results,
  11. )
  12. @plac.annotations(
  13. data_path=("Path to source data", "option", "i", str),
  14. model_path=("Path to save trained Model", "option", "m", str),
  15. out_path=("Path to save trained Model", "option", "o", str),
  16. )
  17. def main(
  18. data_path="data/features/",
  19. model_path="models/",
  20. out_path="models/ensemble/",
  21. ):
  22. X_train, X_test, y_train, y_test = read_data(data_path)
  23. name = "Ensemble"
  24. params = read_yaml("params.yaml", "ensemble")
  25. cl1 = load(f"{model_path}/logistic/model.gz")
  26. cl2 = load(f"{model_path}/svc/model.gz")
  27. cl3 = load(f"{model_path}/r_forrest/model.gz")
  28. estimators = [("l_regression", cl1), ("l_svc", cl2), ("r_forrest", cl3)]
  29. model = VotingClassifier(estimators, **params)
  30. model.fit(X_train, y_train)
  31. accuracy, c_matrix, fig = evaluate_model(model, X_test, y_test)
  32. print_results(accuracy, c_matrix, name)
  33. save_results(out_path, model, fig)
  34. dump_yaml(
  35. dict(accuracy=accuracy, confusion_matrix=c_matrix),
  36. f"{out_path}metrics.yaml",
  37. )
  38. if __name__ == "__main__":
  39. plac.call(main)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...