Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

promote_model.py 5.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
  1. #!/usr/bin/env python3
  2. """
  3. Model promotion script for MLflow models.
  4. Promotes models from staging to production based on performance criteria.
  5. """
  6. import os
  7. import json
  8. import mlflow
  9. from mlflow.tracking import MlflowClient
  10. def setup_mlflow_tracking():
  11. """Setup MLflow tracking with proper authentication"""
  12. # Check if we have authentication credentials
  13. dagshub_token = os.getenv("DAGSHUB_TOKEN")
  14. mlflow_username = os.getenv("MLFLOW_TRACKING_USERNAME", "yahiaehab10")
  15. mlflow_password = os.getenv("MLFLOW_TRACKING_PASSWORD", dagshub_token)
  16. if dagshub_token and mlflow_password:
  17. try:
  18. # Set up DagsHub MLflow with authentication
  19. mlflow.set_tracking_uri(
  20. "https://dagshub.com/yahiaehab10/MLFlow_demo.mlflow"
  21. )
  22. # Try to authenticate by creating a simple connection test
  23. client = MlflowClient()
  24. experiments = client.search_experiments(max_results=1)
  25. print("✓ Successfully connected to DagsHub MLflow")
  26. return True
  27. except Exception as e:
  28. print(f"❌ DagsHub authentication failed: {e}")
  29. print("🔄 Falling back to local MLflow tracking")
  30. mlflow.set_tracking_uri("file:./mlruns")
  31. return False
  32. else:
  33. print("⚠️ No DagsHub credentials found, using local MLflow")
  34. mlflow.set_tracking_uri("file:./mlruns")
  35. return False
  36. def get_latest_experiment_metrics():
  37. """Get metrics from the latest experiment run"""
  38. try:
  39. # Try to read metrics from local file first
  40. if os.path.exists("metrics.json"):
  41. with open("metrics.json", "r") as f:
  42. metrics = json.load(f)
  43. print(f"✓ Found metrics: {metrics}")
  44. return metrics
  45. except Exception as e:
  46. print(f"Could not read local metrics: {e}")
  47. # Fallback to MLflow tracking
  48. try:
  49. client = MlflowClient()
  50. experiments = client.search_experiments()
  51. if experiments:
  52. runs = client.search_runs(experiment_ids=[experiments[0].experiment_id])
  53. if runs:
  54. latest_run = runs[0]
  55. metrics = latest_run.data.metrics
  56. print(f"✓ Found MLflow metrics: {metrics}")
  57. return metrics
  58. except Exception as e:
  59. print(f"Could not read MLflow metrics: {e}")
  60. return {}
  61. def promote_model():
  62. """Main model promotion logic"""
  63. print("🚀 Starting model promotion process...")
  64. # Setup connections
  65. dagshub_connected = setup_mlflow_tracking()
  66. # Get performance metrics
  67. metrics = get_latest_experiment_metrics()
  68. if not metrics:
  69. print("❌ No metrics found - cannot promote model")
  70. return False
  71. # Check if model meets promotion criteria
  72. accuracy = metrics.get("accuracy", 0)
  73. precision = metrics.get("precision", 0)
  74. recall = metrics.get("recall", 0)
  75. print(f"📊 Model Performance:")
  76. print(f" Accuracy: {accuracy:.4f}")
  77. print(f" Precision: {precision:.4f}")
  78. print(f" Recall: {recall:.4f}")
  79. # Define promotion thresholds
  80. min_accuracy = 0.85
  81. min_precision = 0.80
  82. min_recall = 0.80
  83. # Check promotion criteria
  84. meets_criteria = (
  85. accuracy >= min_accuracy and precision >= min_precision and recall >= min_recall
  86. )
  87. if meets_criteria:
  88. print("✅ Model meets promotion criteria!")
  89. print("🎯 Model is ready for production deployment")
  90. # Log promotion decision
  91. try:
  92. with mlflow.start_run(run_name="model_promotion"):
  93. mlflow.log_param("promotion_decision", "approved")
  94. mlflow.log_param("promotion_reason", "meets_all_criteria")
  95. mlflow.log_metric("final_accuracy", accuracy)
  96. mlflow.log_metric("final_precision", precision)
  97. mlflow.log_metric("final_recall", recall)
  98. except Exception as e:
  99. print(f"Warning: Could not log to MLflow: {e}")
  100. return True
  101. else:
  102. print("❌ Model does not meet promotion criteria:")
  103. if accuracy < min_accuracy:
  104. print(f" Accuracy {accuracy:.4f} < {min_accuracy}")
  105. if precision < min_precision:
  106. print(f" Precision {precision:.4f} < {min_precision}")
  107. if recall < min_recall:
  108. print(f" Recall {recall:.4f} < {min_recall}")
  109. # Log rejection
  110. try:
  111. with mlflow.start_run(run_name="model_promotion"):
  112. mlflow.log_param("promotion_decision", "rejected")
  113. mlflow.log_param("promotion_reason", "below_thresholds")
  114. mlflow.log_metric("final_accuracy", accuracy)
  115. except Exception as e:
  116. print(f"Warning: Could not log to MLflow: {e}")
  117. return False
  118. if __name__ == "__main__":
  119. success = promote_model()
  120. if success:
  121. print("🎉 Model promotion completed successfully!")
  122. else:
  123. print("⚠️ Model promotion rejected - improve model performance")
  124. exit(1)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...