evaluate.py

import datetime
import json
from collections import OrderedDict
from typing import Dict

import numpy as np
import ray
import ray.train.torch  # NOQA: F401 (imported but unused)
import typer
from ray.data import Dataset
from ray.train.torch.torch_predictor import TorchPredictor
from sklearn.metrics import precision_recall_fscore_support
from snorkel.slicing import PandasSFApplier, slicing_function
from typing_extensions import Annotated

from madewithml import predict, utils
from madewithml.config import logger

# Initialize Typer CLI app
app = typer.Typer()


def get_overall_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict:  # pragma: no cover, eval workload
    """Get overall performance metrics.

    Args:
        y_true (np.ndarray): ground truth labels.
        y_pred (np.ndarray): predicted labels.

    Returns:
        Dict: overall metrics.
    """
    metrics = precision_recall_fscore_support(y_true, y_pred, average="weighted")
    overall_metrics = {
        "precision": metrics[0],
        "recall": metrics[1],
        "f1": metrics[2],
        "num_samples": np.float64(len(y_true)),
    }
    return overall_metrics
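

# --- Example (illustrative only): what get_overall_metrics returns for toy labels ---
# With y_true = [0, 0, 1, 1] and y_pred = [0, 1, 1, 1], sklearn reports
#   class 0 -> precision 1.00, recall 0.50, f1 0.67
#   class 1 -> precision 0.67, recall 1.00, f1 0.80
# so the support-weighted averages are roughly:
#   get_overall_metrics(np.array([0, 0, 1, 1]), np.array([0, 1, 1, 1]))
#   # -> {"precision": 0.83, "recall": 0.75, "f1": 0.73, "num_samples": 4.0}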


def get_per_class_metrics(y_true: np.ndarray, y_pred: np.ndarray, class_to_index: Dict) -> Dict:  # pragma: no cover, eval workload
    """Get per class performance metrics.

    Args:
        y_true (np.ndarray): ground truth labels.
        y_pred (np.ndarray): predicted labels.
        class_to_index (Dict): dictionary mapping class to index.

    Returns:
        Dict: per class metrics.
    """
    per_class_metrics = {}
    metrics = precision_recall_fscore_support(y_true, y_pred, average=None)
    for i, _class in enumerate(class_to_index):
        per_class_metrics[_class] = {
            "precision": metrics[0][i],
            "recall": metrics[1][i],
            "f1": metrics[2][i],
            "num_samples": np.float64(metrics[3][i]),
        }
    sorted_per_class_metrics = OrderedDict(sorted(per_class_metrics.items(), key=lambda tag: tag[1]["f1"], reverse=True))
    return sorted_per_class_metrics


@slicing_function()
def nlp_llm(x):  # pragma: no cover, eval workload
    """NLP projects that use LLMs."""
    nlp_project = "natural-language-processing" in x.tag
    llm_terms = ["transformer", "llm", "bert"]
    llm_project = any(s.lower() in x.text.lower() for s in llm_terms)
    return nlp_project and llm_project


@slicing_function()
def short_text(x):  # pragma: no cover, eval workload
    """Projects with short titles and descriptions."""
    return len(x.text.split()) < 8  # less than 8 words
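

# --- Example (illustrative only): applying the slicing functions to a toy DataFrame ---
# PandasSFApplier returns a structured NumPy array with one 0/1 field per
# slicing function. The rows below are made up; real rows come from the labeled
# holdout dataset used in get_slice_metrics / evaluate further down.
#
#   import pandas as pd
#   toy = pd.DataFrame({
#       "tag": ["natural-language-processing", "computer-vision"],
#       "text": ["BERT for question answering", "Real-time object detection with a tiny CNN"],
#   })
#   flags = PandasSFApplier([nlp_llm, short_text]).apply(toy)
#   flags.dtype.names    # ("nlp_llm", "short_text")
#   flags["nlp_llm"]     # array([1, 0]) -> only row 0 is an NLP project mentioning "bert"
#   flags["short_text"]  # array([1, 1]) -> both texts are under 8 words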


def get_slice_metrics(y_true: np.ndarray, y_pred: np.ndarray, ds: Dataset) -> Dict:  # pragma: no cover, eval workload
    """Get performance metrics for slices.

    Args:
        y_true (np.ndarray): ground truth labels.
        y_pred (np.ndarray): predicted labels.
        ds (Dataset): Ray dataset with labels.

    Returns:
        Dict: performance metrics for slices.
    """
    slice_metrics = {}
    df = ds.to_pandas()
    df["text"] = df["title"] + " " + df["description"]
    slices = PandasSFApplier([nlp_llm, short_text]).apply(df)
    for slice_name in slices.dtype.names:
        mask = slices[slice_name].astype(bool)
        if sum(mask):
            metrics = precision_recall_fscore_support(y_true[mask], y_pred[mask], average="micro")
            slice_metrics[slice_name] = {}
            slice_metrics[slice_name]["precision"] = metrics[0]
            slice_metrics[slice_name]["recall"] = metrics[1]
            slice_metrics[slice_name]["f1"] = metrics[2]
            slice_metrics[slice_name]["num_samples"] = len(y_true[mask])
    return slice_metrics


@app.command()
def evaluate(
    run_id: Annotated[str, typer.Option(help="id of the specific run to load from")] = None,
    dataset_loc: Annotated[str, typer.Option(help="dataset (with labels) to evaluate on")] = None,
    results_fp: Annotated[str, typer.Option(help="location to save evaluation results to")] = None,
) -> Dict:  # pragma: no cover, eval workload
    """Evaluate on the holdout dataset.

    Args:
        run_id (str): id of the specific run to load from. Defaults to None.
        dataset_loc (str): dataset (with labels) to evaluate on. Defaults to None.
        results_fp (str, optional): location to save evaluation results to. Defaults to None.

    Returns:
        Dict: model's performance metrics on the dataset.
    """
    # Load
    ds = ray.data.read_csv(dataset_loc)
    best_checkpoint = predict.get_best_checkpoint(run_id=run_id)
    predictor = TorchPredictor.from_checkpoint(best_checkpoint)

    # y_true
    preprocessor = predictor.get_preprocessor()
    preprocessed_ds = preprocessor.transform(ds)
    values = preprocessed_ds.select_columns(cols=["targets"]).take_all()
    y_true = np.stack([item["targets"] for item in values])

    # y_pred
    z = predictor.predict(data=ds.to_pandas())["predictions"]
    y_pred = np.stack(z).argmax(1)

    # Metrics
    metrics = {
        "timestamp": datetime.datetime.now().strftime("%B %d, %Y %I:%M:%S %p"),
        "run_id": run_id,
        "overall": get_overall_metrics(y_true=y_true, y_pred=y_pred),
        "per_class": get_per_class_metrics(y_true=y_true, y_pred=y_pred, class_to_index=preprocessor.class_to_index),
        "slices": get_slice_metrics(y_true=y_true, y_pred=y_pred, ds=ds),
    }
    logger.info(json.dumps(metrics, indent=2))
    if results_fp:  # pragma: no cover, saving results
        utils.save_dict(d=metrics, path=results_fp)
    return metrics


if __name__ == "__main__":  # pragma: no cover, checked during evaluation workload
    app()
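
A rough usage sketch (the run id and file paths below are placeholders, not values from this repository): @app.command() registers evaluate with Typer but returns the function unchanged, so it can be driven from the command line (Typer derives the options --run-id, --dataset-loc and --results-fp from the parameter names) or called directly from Python, assuming the module is importable as madewithml.evaluate:

# Illustrative only: placeholder run id and paths, assumed module path madewithml.evaluate.
from madewithml.evaluate import evaluate

metrics = evaluate(
    run_id="<RUN_ID>",                     # id of a completed training run
    dataset_loc="datasets/holdout.csv",    # hypothetical labeled holdout CSV
    results_fp="results/evaluation.json",  # optional: where to persist the metrics
)
print(metrics["overall"])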