Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 3.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
  1. import os
  2. import pandas as pd
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. import joblib
  6. import mlflow
  7. from dotenv import load_dotenv
  8. from sklearn.datasets import load_wine
  9. from sklearn.model_selection import train_test_split
  10. from sklearn.ensemble import GradientBoostingClassifier
  11. from sklearn.metrics import (
  12. classification_report,
  13. confusion_matrix,
  14. ConfusionMatrixDisplay,
  15. accuracy_score
  16. )
  17. from app.utils import load_params
  18. from typing import Tuple, List
  19. # Load environment variables from the .env file
  20. load_dotenv()
  21. # Securely set MLflow environment variables
  22. os.environ["MLFLOW_TRACKING_URI"] = os.getenv("MLFLOW_TRACKING_URI", "")
  23. os.environ["MLFLOW_TRACKING_USERNAME"] = os.getenv("MLFLOW_TRACKING_USERNAME", "")
  24. os.environ["MLFLOW_TRACKING_PASSWORD"] = os.getenv("MLFLOW_TRACKING_PASSWORD", "")
  25. # Load paths from config
  26. params = load_params()
  27. DATA_PATH = params["DATA_PATH"]
  28. MODEL_PATH = params["MODEL_PATH"]
  29. def prepare_data() -> Tuple[pd.DataFrame, pd.Series, List[str]]:
  30. """
  31. Load and prepare the Wine dataset.
  32. Returns:
  33. X (pd.DataFrame): Features
  34. y (pd.Series): Labels
  35. class_labels (List[str]): List of target class names
  36. """
  37. data = load_wine()
  38. X = pd.DataFrame(data.data, columns=data.feature_names)
  39. y = pd.Series(data.target, name="target")
  40. class_labels = list(data.target_names)
  41. df = pd.concat([X, y], axis=1)
  42. os.makedirs(os.path.dirname(DATA_PATH), exist_ok=True)
  43. df.to_csv(DATA_PATH, index=False)
  44. return X, y, class_labels
  45. def train_and_evaluate(X: pd.DataFrame, y: pd.Series, class_labels: List[str]) -> None:
  46. """
  47. Train a Gradient Boosting Classifier and log the model and metrics to MLflow.
  48. Args:
  49. X (pd.DataFrame): Feature dataset
  50. y (pd.Series): Target labels
  51. class_labels (List[str]): Target class names
  52. """
  53. # Split data
  54. X_train, X_test, y_train, y_test = train_test_split(
  55. X, y, test_size=0.2, random_state=42
  56. )
  57. # Initialize model
  58. model = GradientBoostingClassifier(random_state=42)
  59. # Set MLflow tracking and experiment
  60. mlflow.set_tracking_uri(os.environ["MLFLOW_TRACKING_URI"])
  61. mlflow.set_experiment("Wine_Classifier_Experiment_with_docker")
  62. # Start MLflow run
  63. with mlflow.start_run():
  64. # Log model parameters
  65. mlflow.log_param("model_type", "GradientBoostingClassifier")
  66. mlflow.log_param("random_state", 42)
  67. # Train model
  68. model.fit(X_train, y_train)
  69. # Predict
  70. y_pred = model.predict(X_test)
  71. # Log accuracy
  72. accuracy = accuracy_score(y_test, y_pred)
  73. mlflow.log_metric("accuracy", accuracy)
  74. # Log full classification report
  75. report = classification_report(y_test, y_pred, output_dict=True)
  76. for label, metrics in report.items():
  77. if isinstance(metrics, dict):
  78. for metric_name, value in metrics.items():
  79. mlflow.log_metric(f"{label}_{metric_name}", value)
  80. # Generate and log confusion matrix
  81. cm = confusion_matrix(y_test, y_pred)
  82. disp = ConfusionMatrixDisplay(cm, display_labels=class_labels)
  83. disp.plot(cmap="Blues")
  84. plt.title("Confusion Matrix")
  85. plt.savefig("confusion_matrix.png")
  86. mlflow.log_artifact("confusion_matrix.png")
  87. #plt.show()
  88. # Save the trained model
  89. os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
  90. joblib.dump(model, MODEL_PATH)
  91. # Log the model in MLflow
  92. mlflow.sklearn.log_model(model, artifact_path="model")
  93. print(f"Model saved to: {MODEL_PATH}")
  94. print("Model and metrics successfully logged to MLflow.")
  95. def main() -> None:
  96. """Main execution function."""
  97. X, y, class_labels = prepare_data()
  98. train_and_evaluate(X, y, class_labels)
  99. if __name__ == "__main__":
  100. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...