Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 2.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
  1. import sys
  2. import os
  3. import pickle
  4. import json
  5. import yaml
  6. from sklearn.model_selection import train_test_split
  7. from sklearn.preprocessing import LabelBinarizer
  8. from sklearn.metrics import accuracy_score
  9. import matplotlib.pyplot as plt
  10. import plots
  11. params = yaml.safe_load(open('params.yaml'))['evaluate']
  12. seed = params['seed']
  13. split = params['split']
  14. model_type = params['model']
  15. model_file = os.path.join('model', sys.argv[1]) #model.pkl
  16. input = os.path.join(sys.argv[2], 'data.pkl') #data/prepared
  17. scores_file = os.path.join('reports', 'metrics', '{}_{}'.format(model_type, sys.argv[3])) #scores.json
  18. confusion_matrix_plots_file = os.path.join('reports', 'plots', '{}_{}'.format(model_type, sys.argv[4])) #confusion_matrix.png
  19. roc_auc_plots_file = os.path.join('reports', 'plots', '{}_{}'.format(model_type, sys.argv[5])) #ROC_AUC_curve.png
  20. def split_dataset(df):
  21. X = df[df.columns[:-1]]
  22. y = df['Star type']
  23. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=split, random_state=seed)
  24. return X_train, X_test, y_train, y_test
  25. # Confusion matrix
  26. def plot_confusion_matrix(model, X_test, confusion_matrix_plots_file, y_test):
  27. predictions = model.predict(X_test)
  28. with open(confusion_matrix_plots_file, 'w') as fd:
  29. fig = plots.confusion_matrix_plot(y_test, predictions, [0,1,2,3,4,5])
  30. plt.savefig(confusion_matrix_plots_file)
  31. plt.close()
  32. # ROC AUC Curve
  33. def plot_roc_auc(y_test, roc_auc_plots_file, X_test, model_type, model):
  34. # convert classes to binaries
  35. lb = LabelBinarizer()
  36. y_test = lb.fit_transform(y_test)
  37. with open(roc_auc_plots_file, 'w') as fd:
  38. fig = plots.roc_auc_multiclass(X_test, y_test, model_type, model)
  39. plt.savefig(roc_auc_plots_file)
  40. plt.close()
  41. ## Metrics
  42. # # ROC AUC metric
  43. # def return_roc_auc(X_test, y_test):
  44. # predictions = model.predict_proba(X_test)
  45. # roc_auc = roc_auc_score(y_test, predictions, multi_class='ovr')
  46. # return roc_auc
  47. os.makedirs('reports', exist_ok=True)
  48. os.makedirs('reports/metrics', exist_ok=True)
  49. os.makedirs('reports/plots', exist_ok=True)
  50. with open(model_file, 'rb') as fd:
  51. model = pickle.load(fd)
  52. with open(input, 'rb') as fd:
  53. df = pickle.load(fd)
  54. X_train, X_test, y_train, y_test = split_dataset(df)
  55. with open(scores_file, 'w') as fd:
  56. json.dump({
  57. 'model type': model_type,
  58. 'Accuracy Score': accuracy_score(y_test, model.predict(X_test)),
  59. 'best parameters': model.best_params_
  60. },
  61. fd
  62. )
  63. # Create metric plots
  64. plot_confusion_matrix(model, X_test, confusion_matrix_plots_file, y_test)
  65. plot_roc_auc(y_test, roc_auc_plots_file, X_test, model_type, model)
  66. # python src/evaluate.py model.pkl data/prepared scores.json confusion_matrix.png ROC_AUC_curve.png
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...