Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
  1. import json
  2. import math
  3. import os
  4. import pickle
  5. import sys
  6. import pandas as pd
  7. from sklearn import metrics
  8. from sklearn import tree
  9. from dvclive import Live
  10. from matplotlib import pyplot as plt
  11. def evaluate(model, matrix, split, live, save_path):
  12. """
  13. Dump all evaluation metrics and plots for given datasets.
  14. Args:
  15. model (sklearn.ensemble.RandomForestClassifier): Trained classifier.
  16. matrix (scipy.sparse.csr_matrix): Input matrix.
  17. split (str): Dataset name.
  18. live (dvclive.Live): Dvclive instance.
  19. save_path (str): Path to save the metrics.
  20. """
  21. labels = matrix[:, 1].toarray().astype(int)
  22. x = matrix[:, 2:]
  23. predictions_by_class = model.predict_proba(x)
  24. predictions = predictions_by_class[:, 1]
  25. # Use dvclive to log a few simple metrics...
  26. avg_prec = metrics.average_precision_score(labels, predictions)
  27. roc_auc = metrics.roc_auc_score(labels, predictions)
  28. if not live.summary:
  29. live.summary = {"avg_prec": {}, "roc_auc": {}}
  30. live.summary["avg_prec"][split] = avg_prec
  31. live.summary["roc_auc"][split] = roc_auc
  32. # ... and plots...
  33. # ... like an roc plot...
  34. live.log_sklearn_plot("roc", labels, predictions, name=f"roc/{split}")
  35. # ... and precision recall plot...
  36. # ... which passes `drop_intermediate=True` to the sklearn method...
  37. live.log_sklearn_plot(
  38. "precision_recall",
  39. labels,
  40. predictions,
  41. name=f"prc/{split}",
  42. drop_intermediate=True,
  43. )
  44. # ... and confusion matrix plot
  45. live.log_sklearn_plot(
  46. "confusion_matrix",
  47. labels.squeeze(),
  48. predictions_by_class.argmax(-1),
  49. name=f"cm/{split}",
  50. )
  51. def save_importance_plot(live, model, feature_names):
  52. """
  53. Save feature importance plot.
  54. Args:
  55. live (dvclive.Live): DVCLive instance.
  56. model (sklearn.ensemble.RandomForestClassifier): Trained classifier.
  57. feature_names (list): List of feature names.
  58. """
  59. fig, axes = plt.subplots(dpi=100)
  60. fig.subplots_adjust(bottom=0.2, top=0.95)
  61. axes.set_ylabel("Mean decrease in impurity")
  62. importances = model.feature_importances_
  63. forest_importances = pd.Series(importances, index=feature_names).nlargest(n=30)
  64. forest_importances.plot.bar(ax=axes)
  65. live.log_image("importance.png", fig)
  66. def main():
  67. EVAL_PATH = "eval"
  68. if len(sys.argv) != 3:
  69. sys.stderr.write("Arguments error. Usage:\n")
  70. sys.stderr.write("\tpython evaluate.py model features\n")
  71. sys.exit(1)
  72. model_file = sys.argv[1]
  73. train_file = os.path.join(sys.argv[2], "train.pkl")
  74. test_file = os.path.join(sys.argv[2], "test.pkl")
  75. # Load model and data.
  76. with open(model_file, "rb") as fd:
  77. model = pickle.load(fd)
  78. with open(train_file, "rb") as fd:
  79. train, feature_names = pickle.load(fd)
  80. with open(test_file, "rb") as fd:
  81. test, _ = pickle.load(fd)
  82. # Evaluate train and test datasets.
  83. with Live(EVAL_PATH, dvcyaml=False) as live:
  84. evaluate(model, train, "train", live, save_path=EVAL_PATH)
  85. evaluate(model, test, "test", live, save_path=EVAL_PATH)
  86. # Dump feature importance plot.
  87. save_importance_plot(live, model, feature_names)
  88. if __name__ == "__main__":
  89. main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...