Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

sweep_slurm.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
  1. import argparse
  2. import json
  3. import os
  4. import subprocess
  5. from pathlib import Path
  6. import yaml
  7. import wandb
  # Export the W&B API key from keys.json (field "wandb_key"), if that file
  # exists in the current working directory, so wandb calls made by this
  # process can authenticate. Child processes started with subprocess inherit
  # os.environ by default, so they see the key too.
  if Path("keys.json").is_file():
      with open("keys.json", encoding="utf-8") as file:
          api_key = json.load(file)["wandb_key"]
      os.environ["WANDB_API_KEY"] = api_key

  # Gather the hostnames of the nodes allocated to the current SLURM job, one
  # per line of `scontrol show hostnames` output. splitlines() plus the
  # blank-line filter drops the trailing newline without also discarding the
  # last host when the output does not end in a newline. scontrol's exit
  # status is not checked: if it prints nothing, node_list is empty.
  result = subprocess.run(
      ["scontrol", "show", "hostnames"], stdout=subprocess.PIPE, encoding="utf-8"
  )
  node_list = [line.strip() for line in result.stdout.splitlines() if line.strip()]
  def main(argv=None):
      """Create a W&B sweep and start one sweep agent on every allocated node.

      Positional command-line arguments:
          sweep_config: path to the YAML sweep configuration file.
          train_script: stored as the config's ``program`` entry, overriding
              any value the file sets.
          project: W&B project passed to ``wandb.init`` and ``wandb.sweep``
              and forwarded to every agent.

      Args:
          argv: argument list to parse. ``None`` (the default) parses
              ``sys.argv[1:]``, so existing ``main()`` calls are unaffected.

      Returns:
          The ``srun`` exit codes, one per entry of the module-level
          ``node_list`` and in the same order.

      Exits with status 2 (via ``parser.error``) if ``node_list`` is empty or
      the sweep config does not contain a YAML mapping.
      """
      parser = argparse.ArgumentParser()
      parser.add_argument("sweep_config", type=str)
      parser.add_argument("train_script", type=str)
      parser.add_argument("project", type=str)
      args = parser.parse_args(argv)

      # Refuse to register a sweep that no agent would ever run.
      if not node_list:
          parser.error(
              "no nodes reported by `scontrol show hostnames`; "
              "run inside a SLURM allocation"
          )

      # NOTE(review): this opens a W&B run in the launcher process itself;
      # confirm it is needed, since the sweep is created by wandb.sweep below.
      wandb.init(project=args.project)

      # safe_load is sufficient for a plain-data sweep config and restricts
      # construction to standard YAML types (no Python-specific tags).
      with open(args.sweep_config, encoding="utf-8") as file:
          config_dict = yaml.safe_load(file)
      if not isinstance(config_dict, dict):
          parser.error(
              f"sweep config {args.sweep_config!r} must contain a YAML mapping"
          )
      config_dict["program"] = args.train_script
      sweep_id = wandb.sweep(config_dict, project=args.project)

      # Launch every agent before waiting on any, so all nodes run in parallel.
      agents = []
      for node in node_list:
          command = [
              "srun",
              "--nodes=1",
              "--ntasks=1",
              "-w",
              node,
              "start-agent.sh",
              sweep_id,
              args.project,
          ]
          agents.append(subprocess.Popen(command))
      return [agent.wait() for agent in agents]
  if __name__ == "__main__":
      # Propagate agent failures to the caller (e.g. the batch script): exit
      # with the first non-zero srun exit code, or 0 if every agent succeeded,
      # instead of always exiting 0.
      raise SystemExit(next((code for code in main() if code != 0), 0))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...