Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

app.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
  1. from fastapi import FastAPI, HTTPException
  2. from pydantic import BaseModel
  3. import numpy as np
  4. import joblib
  5. import uvicorn
  6. from typing import List
  7. import os
  8. # Initialize FastAPI app
  9. app = FastAPI(
  10. title="Iris Model Prediction API",
  11. description="API for making predictions using the trained Iris model",
  12. version="1.0.0",
  13. )
# Define input data model
class IrisData(BaseModel):
    """Request body accepted by ``POST /predict``.

    ``features`` is a flat list of floats describing a single sample; the
    ``/predict`` handler wraps it into a one-row batch before calling the
    model. The schema does not enforce a length — presumably the model
    expects the four iris measurements shown in the example below (TODO
    confirm against the training script).
    """

    features: List[float]

    class Config:
        # Example payload shown in the generated OpenAPI docs.
        # NOTE(review): ``schema_extra`` is the pydantic v1 config key; pydantic
        # v2 renamed it to ``json_schema_extra`` (and prefers ``model_config``).
        # Confirm which pydantic version is installed.
        schema_extra = {
            "example": {"features": [5.1, 3.5, 1.4, 0.2]}  # Example of iris features
        }
  21. # Load the model at startup
  22. @app.on_event("startup")
  23. async def load_model():
  24. global model
  25. # Try loading joblib model first (simple and direct)
  26. if os.path.exists("models/iris_model.pkl"):
  27. try:
  28. model = joblib.load("models/iris_model.pkl")
  29. print("Model loaded successfully from: models/iris_model.pkl")
  30. return
  31. except Exception as e:
  32. print(f"Failed to load joblib model: {str(e)}")
  33. # Fallback to MLflow approaches
  34. import mlflow
  35. model_approaches = [
  36. ("models:/IrisRandomForest/Staging", "Model Registry - Staging"),
  37. ("models:/IrisRandomForest/latest", "Model Registry - Latest"),
  38. ("mlruns/models", "Local models directory"),
  39. ("mlruns/0/85352b5f8d474b4f850f206501da8f7b/artifacts/model", "Run artifacts"),
  40. ]
  41. for model_uri, description in model_approaches:
  42. try:
  43. model = mlflow.pyfunc.load_model(model_uri)
  44. print(f"Model loaded successfully from: {description} ({model_uri})")
  45. return
  46. except Exception as e:
  47. print(f"Failed to load model from {description}: {str(e)}")
  48. continue
  49. model = None
  50. print(
  51. "Warning: No model could be loaded. API will return errors for prediction requests."
  52. )
  53. print(
  54. "Please ensure you have trained a model first by running: python simple_train.py"
  55. )
  56. @app.get("/")
  57. async def root():
  58. return {"message": "Welcome to the Iris Model Prediction API"}
  59. @app.post("/predict")
  60. async def predict(data: IrisData):
  61. if model is None:
  62. raise HTTPException(status_code=500, detail="Model not loaded")
  63. try:
  64. # Convert input features to numpy array
  65. features = np.array([data.features])
  66. # Make prediction
  67. prediction = model.predict(features)
  68. # Convert prediction to Python type for JSON serialization
  69. prediction = (
  70. prediction.tolist()[0] if isinstance(prediction, np.ndarray) else prediction
  71. )
  72. return {"prediction": prediction, "features": data.features}
  73. except Exception as e:
  74. raise HTTPException(status_code=500, detail=str(e))
  75. if __name__ == "__main__":
  76. uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...