Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

setup_dagshub_example.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  1. import os
  2. import time
  3. from git import Repo, GitCommandError
  4. import mlflow
  5. import dagshub
  6. from ray import air, tune
  7. from ray.air import session
  8. from ray.air.integrations.dagshub import setup_dagshub, upload_artifacts
  9. def evaluation_fn(step, width, height):
  10. return (0.1 + width * step / 100) ** (-1) + height * 0.1
  11. def train_function_dagshub(config):
  12. # Hyperparameters
  13. width, height = config["width"], config["height"]
  14. setup_dagshub(config=config)
  15. for step in range(config.get("steps", 100)):
  16. # Iterative training function - can be any arbitrary training procedure
  17. intermediate_score = evaluation_fn(step, width, height)
  18. # Log the metrics to mlflow
  19. mlflow.log_metrics(dict(mean_loss=intermediate_score), step=step)
  20. # Feed the score back to Tune.
  21. session.report({"iterations": step, "mean_loss": intermediate_score})
  22. time.sleep(0.1)
  23. def tune_with_setup(save_artifact=False):
  24. DAGSHUB_REPO = "timho102003/raytest_0607_v12"
  25. dagshub.init(repo_name=DAGSHUB_REPO.split(os.sep)[1], repo_owner=DAGSHUB_REPO.split(os.sep)[0])
  26. # Set the experiment, or create a new one if does not exist yet.
  27. mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", None))
  28. mlflow.set_experiment(experiment_name="mixin_example")
  29. tuner = tune.Tuner(
  30. train_function_dagshub,
  31. tune_config=tune.TuneConfig(
  32. num_samples=5
  33. ),
  34. run_config=air.RunConfig(
  35. name="dagshub",
  36. ),
  37. param_space={
  38. "width": tune.randint(10, 100),
  39. "height": tune.randint(0, 100),
  40. "steps": 5,
  41. "dagshub": {
  42. "experiment_name": "mixin_example",
  43. "dagshub_repository": DAGSHUB_REPO,
  44. "log_mlflow_only": False,
  45. },
  46. },
  47. )
  48. results = tuner.fit()
  49. if save_artifact:
  50. print("Save Artifact!")
  51. upload_artifacts(results=results, repo_name=DAGSHUB_REPO.split(os.sep)[1], repo_owner=DAGSHUB_REPO.split(os.sep)[0])
  52. tune_with_setup(save_artifact=True)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...