Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

simple_train.py 1.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
  1. #!/usr/bin/env python3
  2. """
  3. Simple script to train and save an iris model for serving
  4. """
  5. import mlflow
  6. import mlflow.sklearn
  7. import joblib
  8. import os
  9. from sklearn.datasets import load_iris
  10. from sklearn.ensemble import RandomForestClassifier
  11. from sklearn.model_selection import train_test_split
  12. from sklearn.metrics import accuracy_score
  13. def train_and_save_model():
  14. # Load data
  15. iris = load_iris()
  16. X, y = iris.data, iris.target
  17. # Split data
  18. X_train, X_test, y_train, y_test = train_test_split(
  19. X, y, test_size=0.2, random_state=42
  20. )
  21. # Train model
  22. model = RandomForestClassifier(n_estimators=100, random_state=42)
  23. model.fit(X_train, y_train)
  24. # Evaluate
  25. y_pred = model.predict(X_test)
  26. accuracy = accuracy_score(y_test, y_pred)
  27. print(f"Model accuracy: {accuracy:.3f}")
  28. # Create models directory
  29. os.makedirs("models", exist_ok=True)
  30. # Save model with joblib for direct loading
  31. joblib.dump(model, "models/iris_model.pkl")
  32. print("Model saved to models/iris_model.pkl")
  33. # Also save with MLflow
  34. with mlflow.start_run():
  35. mlflow.sklearn.log_model(model, "iris_model", registered_model_name="IrisModel")
  36. mlflow.log_metric("accuracy", accuracy)
  37. print("Model logged to MLflow")
  38. return model
  39. if __name__ == "__main__":
  40. train_and_save_model()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...