serve.py 2.6 KB
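
serve.py packages the Made With ML classifier as an online service: a FastAPI app wrapped in a Ray Serve deployment that loads the best checkpoint for a given MLflow run and exposes health-check (/), run-ID (/run_id/), evaluation (/evaluate/), and prediction (/predict/) endpoints.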

import argparse
from http import HTTPStatus
from typing import Dict

import pandas as pd
import ray
from fastapi import FastAPI
from ray import serve
from ray.train.torch import TorchPredictor
from starlette.requests import Request

from madewithml import evaluate, predict
from madewithml.config import MLFLOW_TRACKING_URI, mlflow

# Define application
app = FastAPI(
    title="Made With ML",
    description="Classify machine learning projects.",
    version="0.1",
)


@serve.deployment(route_prefix="/", num_replicas=1, ray_actor_options={"num_cpus": 8, "num_gpus": 0})
@serve.ingress(app)
class ModelDeployment:
    def __init__(self, run_id: str, threshold: float = 0.9):
        """Initialize the model."""
        self.run_id = run_id
        self.threshold = threshold
        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)  # so workers have access to model registry
        best_checkpoint = predict.get_best_checkpoint(run_id=run_id)
        self.predictor = TorchPredictor.from_checkpoint(best_checkpoint)
        self.preprocessor = self.predictor.get_preprocessor()

    @app.get("/")
    def _index(self) -> Dict:
        """Health check."""
        response = {
            "message": HTTPStatus.OK.phrase,
            "status-code": HTTPStatus.OK,
            "data": {},
        }
        return response

    @app.get("/run_id/")
    def _run_id(self) -> Dict:
        """Get the run ID."""
        return {"run_id": self.run_id}

    @app.post("/evaluate/")
    async def _evaluate(self, request: Request) -> Dict:
        """Evaluate the model on a labeled dataset (location passed in the request body)."""
        data = await request.json()
        results = evaluate.evaluate(run_id=self.run_id, dataset_loc=data.get("dataset"))
        return {"results": results}

    @app.post("/predict/")
    async def _predict(self, request: Request) -> Dict:
        """Classify a project given its title and description."""
        # Get prediction
        data = await request.json()
        df = pd.DataFrame([{"title": data.get("title", ""), "description": data.get("description", ""), "tag": ""}])
        results = predict.predict_with_proba(df=df, predictor=self.predictor)

        # Apply custom logic: fall back to the `other` class for low-confidence predictions
        for i, result in enumerate(results):
            pred = result["prediction"]
            prob = result["probabilities"]
            if prob[pred] < self.threshold:
                results[i]["prediction"] = "other"

        return {"results": results}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_id", help="run ID to use for serving.")
    parser.add_argument("--threshold", type=float, default=0.9, help="threshold for `other` class.")
    args = parser.parse_args()
    ray.init()
    serve.run(ModelDeployment.bind(run_id=args.run_id, threshold=args.threshold))
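
Once the service is running (e.g. python serve.py --run_id $RUN_ID), its endpoints can be exercised with any HTTP client. Below is a minimal client sketch using the requests library; it assumes Ray Serve's default local HTTP address (http://127.0.0.1:8000), and the example payload is illustrative.

# query_serve.py -- minimal client sketch; assumes the deployment above is
# running locally at Ray Serve's default HTTP address (port 8000).
import requests

base_url = "http://127.0.0.1:8000"

# Health check: should return {"message": "OK", "status-code": 200, "data": {}}
print(requests.get(f"{base_url}/").json())

# Classify a project; predictions below the confidence threshold come back as "other"
payload = {
    "title": "Transfer learning with transformers",
    "description": "Using transformers for transfer learning on text classification projects.",
}
print(requests.post(f"{base_url}/predict/", json=payload).json())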