Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

predict.py 4.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
  1. import json
  2. from typing import Any, Dict, Iterable, List
  3. from urllib.parse import urlparse
  4. import pandas as pd
  5. import ray
  6. import torch
  7. import typer
  8. from numpyencoder import NumpyEncoder
  9. from ray.air import Result
  10. from ray.train.torch import TorchPredictor
  11. from ray.train.torch.torch_checkpoint import TorchCheckpoint
  12. from typing_extensions import Annotated
  13. from madewithml.config import logger, mlflow
  14. # Initialize Typer CLI app
  15. app = typer.Typer()
  16. def decode(indices: Iterable[Any], index_to_class: Dict) -> List:
  17. """Decode indices to labels.
  18. Args:
  19. indices (Iterable[Any]): Iterable (list, array, etc.) with indices.
  20. index_to_class (Dict): mapping between indices and labels.
  21. Returns:
  22. List: list of labels.
  23. """
  24. return [index_to_class[index] for index in indices]
  25. def format_prob(prob: Iterable, index_to_class: Dict) -> Dict:
  26. """Format probabilities to a dictionary mapping class label to probability.
  27. Args:
  28. prob (Iterable): probabilities.
  29. index_to_class (Dict): mapping between indices and labels.
  30. Returns:
  31. Dict: Dictionary mapping class label to probability.
  32. """
  33. d = {}
  34. for i, item in enumerate(prob):
  35. d[index_to_class[i]] = item
  36. return d
  37. def predict_with_proba(
  38. df: pd.DataFrame,
  39. predictor: ray.train.torch.torch_predictor.TorchPredictor,
  40. ) -> List: # pragma: no cover, tested with inference workload
  41. """Predict tags (with probabilities) for input data from a dataframe.
  42. Args:
  43. df (pd.DataFrame): dataframe with input features.
  44. predictor (ray.train.torch.torch_predictor.TorchPredictor): loaded predictor from a checkpoint.
  45. Returns:
  46. List: list of predicted labels.
  47. """
  48. preprocessor = predictor.get_preprocessor()
  49. z = predictor.predict(data=df)["predictions"]
  50. import numpy as np
  51. y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
  52. results = []
  53. for i, prob in enumerate(y_prob):
  54. tag = decode([z[i].argmax()], preprocessor.index_to_class)[0]
  55. results.append({"prediction": tag, "probabilities": format_prob(prob, preprocessor.index_to_class)})
  56. return results
  57. @app.command()
  58. def get_best_run_id(experiment_name: str = "", metric: str = "", mode: str = "") -> str: # pragma: no cover, mlflow logic
  59. """Get the best run_id from an MLflow experiment.
  60. Args:
  61. experiment_name (str): name of the experiment.
  62. metric (str): metric to filter by.
  63. mode (str): direction of metric (ASC/DESC).
  64. Returns:
  65. str: best run id from experiment.
  66. """
  67. sorted_runs = mlflow.search_runs(
  68. experiment_names=[experiment_name],
  69. order_by=[f"metrics.{metric} {mode}"],
  70. )
  71. run_id = sorted_runs.iloc[0].run_id
  72. print(run_id)
  73. return run_id
  74. def get_best_checkpoint(run_id: str) -> TorchCheckpoint: # pragma: no cover, mlflow logic
  75. """Get the best checkpoint from a specific run.
  76. Args:
  77. run_id (str): ID of the run to get the best checkpoint from.
  78. Returns:
  79. TorchCheckpoint: Best checkpoint from the run.
  80. """
  81. artifact_dir = urlparse(mlflow.get_run(run_id).info.artifact_uri).path # get path from mlflow
  82. results = Result.from_path(artifact_dir)
  83. return results.best_checkpoints[0][0]
  84. @app.command()
  85. def predict(
  86. run_id: Annotated[str, typer.Option(help="id of the specific run to load from")] = None,
  87. title: Annotated[str, typer.Option(help="project title")] = None,
  88. description: Annotated[str, typer.Option(help="project description")] = None,
  89. ) -> Dict: # pragma: no cover, tested with inference workload
  90. """Predict the tag for a project given it's title and description.
  91. Args:
  92. run_id (str): id of the specific run to load from. Defaults to None.
  93. title (str, optional): project title. Defaults to "".
  94. description (str, optional): project description. Defaults to "".
  95. Returns:
  96. Dict: prediction results for the input data.
  97. """
  98. # Load components
  99. best_checkpoint = get_best_checkpoint(run_id=run_id)
  100. predictor = TorchPredictor.from_checkpoint(best_checkpoint)
  101. preprocessor = predictor.get_preprocessor()
  102. # Predict
  103. sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
  104. results = predict_with_proba(df=sample_df, predictor=predictor, index_to_class=preprocessor.index_to_class)
  105. logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
  106. return results
  107. if __name__ == "__main__": # pragma: no cover, application
  108. app()
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...