linear_svc.py 895 B

import os

import plac
from sklearn.svm import LinearSVC

from src.utils import (
    dump_yaml,
    evaluate_model,
    print_results,
    read_data,
    read_yaml,
    save_results,
)


@plac.annotations(
    data_path=("Path to source data", "option", "i", str),
    out_path=("Path to save trained Model", "option", "o", str),
)
def main(data_path="data/features/", out_path="models/svc/"):
    # Load the train/test features and labels from data_path.
    X_train, X_test, y_train, y_test = read_data(data_path)

    # Build the classifier from the "svc" entry read out of params.yaml.
    name = "LinearSVC"
    params = read_yaml("params.yaml", "svc")
    model = LinearSVC(**params)
    model.fit(X_train, y_train)

    # Evaluate on the test split, then print and save the results.
    accuracy, c_matrix, fig = evaluate_model(model, X_test, y_test)
    print_results(accuracy, c_matrix, name)
    save_results(out_path, model, fig)
    dump_yaml(
        dict(accuracy=accuracy, confusion_matrix=c_matrix),
        # os.path.join keeps the path valid when out_path has no trailing slash.
        os.path.join(out_path, "metrics.yaml"),
    )


if __name__ == "__main__":
    plac.call(main)