Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

setup_mlflow_example.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  1. import os
  2. import tempfile
  3. import time
  4. import mlflow
  5. from ray import air, tune
  6. from ray.air import session
  7. from ray.air.integrations.mlflow import MLflowLoggerCallback, setup_mlflow
  8. def evaluation_fn(step, width, height):
  9. return (0.1 + width * step / 100) ** (-1) + height * 0.1
  10. def train_function(config):
  11. width, height = config["width"], config["height"]
  12. for step in range(config.get("steps", 100)):
  13. # Iterative training function - can be any arbitrary training procedure
  14. intermediate_score = evaluation_fn(step, width, height)
  15. # Feed the score back to Tune.
  16. session.report({"iterations": step, "mean_loss": intermediate_score})
  17. time.sleep(0.1)
  18. def train_function_mlflow(config):
  19. setup_mlflow(config)
  20. # Hyperparameters
  21. width, height = config["width"], config["height"]
  22. for step in range(config.get("steps", 100)):
  23. # Iterative training function - can be any arbitrary training procedure
  24. intermediate_score = evaluation_fn(step, width, height)
  25. # Log the metrics to mlflow
  26. mlflow.log_metrics(dict(mean_loss=intermediate_score), step=step)
  27. # Feed the score back to Tune.
  28. session.report({"iterations": step, "mean_loss": intermediate_score})
  29. time.sleep(0.1)
  30. def tune_with_setup(mlflow_tracking_uri, finish_fast=False):
  31. # Set the experiment, or create a new one if does not exist yet.
  32. mlflow.set_tracking_uri(mlflow_tracking_uri)
  33. mlflow.set_experiment(experiment_name="mixin_example")
  34. tuner = tune.Tuner(
  35. train_function_mlflow,
  36. tune_config=tune.TuneConfig(
  37. num_samples=5
  38. ),
  39. run_config=air.RunConfig(
  40. name="mlflow",
  41. ),
  42. param_space={
  43. "width": tune.randint(10, 100),
  44. "height": tune.randint(0, 100),
  45. "steps": 5 if finish_fast else 100,
  46. "mlflow": {
  47. "experiment_name": "mixin_example",
  48. "tracking_uri": mlflow.get_tracking_uri(),
  49. },
  50. },
  51. )
  52. results = tuner.fit()
  53. smoke_test = True
  54. mlflow_tracking_uri = os.path.join(tempfile.gettempdir(), "mlruns")
  55. tune_with_setup(mlflow_tracking_uri, finish_fast=smoke_test)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...