tune.py

import datetime
import json

import ray
import typer
from ray import tune
from ray.air.config import (
    CheckpointConfig,
    DatasetConfig,
    RunConfig,
    ScalingConfig,
)
from ray.air.integrations.mlflow import MLflowLoggerCallback
from ray.train.torch import TorchTrainer
from ray.tune import Tuner
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.hyperopt import HyperOptSearch
from typing_extensions import Annotated

from madewithml import data, train, utils
from madewithml.config import MLFLOW_TRACKING_URI, logger

# Initialize Typer CLI app
app = typer.Typer()


@app.command()
def tune_models(
    experiment_name: Annotated[str, typer.Option(help="name of the experiment for this training workload.")] = None,
    dataset_loc: Annotated[str, typer.Option(help="location of the dataset.")] = None,
    initial_params: Annotated[str, typer.Option(help="initial config for the tuning workload.")] = None,
    num_workers: Annotated[int, typer.Option(help="number of workers to use for training.")] = 1,
    cpu_per_worker: Annotated[int, typer.Option(help="number of CPUs to use per worker.")] = 1,
    gpu_per_worker: Annotated[int, typer.Option(help="number of GPUs to use per worker.")] = 0,
    num_runs: Annotated[int, typer.Option(help="number of runs in this tuning experiment.")] = 1,
    num_samples: Annotated[int, typer.Option(help="number of samples to use from dataset.")] = None,
    num_epochs: Annotated[int, typer.Option(help="number of epochs to train for.")] = 1,
    batch_size: Annotated[int, typer.Option(help="number of samples per batch.")] = 256,
    results_fp: Annotated[str, typer.Option(help="filepath to save results to.")] = None,
) -> ray.tune.result_grid.ResultGrid:
    """Hyperparameter tuning experiment.

    Args:
        experiment_name (str): name of the experiment for this training workload.
        dataset_loc (str): location of the dataset.
        initial_params (str): initial config for the tuning workload.
        num_workers (int, optional): number of workers to use for training. Defaults to 1.
        cpu_per_worker (int, optional): number of CPUs to use per worker. Defaults to 1.
        gpu_per_worker (int, optional): number of GPUs to use per worker. Defaults to 0.
        num_runs (int, optional): number of runs in this tuning experiment. Defaults to 1.
        num_samples (int, optional): number of samples to use from dataset.
            If this is passed in, it will override the config. Defaults to None.
        num_epochs (int, optional): number of epochs to train for.
            If this is passed in, it will override the config. Defaults to 1.
        batch_size (int, optional): number of samples per batch.
            If this is passed in, it will override the config. Defaults to 256.
        results_fp (str, optional): filepath to save the tuning results. Defaults to None.

    Returns:
        ray.tune.result_grid.ResultGrid: results of the tuning experiment.
    """
    # Set up
    utils.set_seeds()
    train_loop_config = {}
    train_loop_config["num_samples"] = num_samples
    train_loop_config["num_epochs"] = num_epochs
    train_loop_config["batch_size"] = batch_size

    # Scaling config
    scaling_config = ScalingConfig(
        num_workers=num_workers,
        use_gpu=bool(gpu_per_worker),
        resources_per_worker={"CPU": cpu_per_worker, "GPU": gpu_per_worker},
        _max_cpu_fraction_per_node=0.8,
    )

    # Dataset
    ds = data.load_data(dataset_loc=dataset_loc, num_samples=train_loop_config.get("num_samples", None))
    train_ds, val_ds = data.stratify_split(ds, stratify="tag", test_size=0.2)
    tags = train_ds.unique(column="tag")
    train_loop_config["num_classes"] = len(tags)

    # Dataset config
    dataset_config = {
        "train": DatasetConfig(fit=False, transform=False, randomize_block_order=False),
        "val": DatasetConfig(fit=False, transform=False, randomize_block_order=False),
    }

    # Preprocess
    preprocessor = data.CustomPreprocessor()
    train_ds = preprocessor.fit_transform(train_ds)
    val_ds = preprocessor.transform(val_ds)
    train_ds = train_ds.materialize()
    val_ds = val_ds.materialize()

    # Trainer
    trainer = TorchTrainer(
        train_loop_per_worker=train.train_loop_per_worker,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
        datasets={"train": train_ds, "val": val_ds},
        dataset_config=dataset_config,
        preprocessor=preprocessor,
    )

    # Checkpoint configuration
    checkpoint_config = CheckpointConfig(
        num_to_keep=1,
        checkpoint_score_attribute="val_loss",
        checkpoint_score_order="min",
    )

    # Run configuration
    mlflow_callback = MLflowLoggerCallback(
        tracking_uri=MLFLOW_TRACKING_URI,
        experiment_name=experiment_name,
        save_artifact=True,
    )
    run_config = RunConfig(
        callbacks=[mlflow_callback],
        checkpoint_config=checkpoint_config,
    )

    # Hyperparameters to start with
    initial_params = json.loads(initial_params)
    search_alg = HyperOptSearch(points_to_evaluate=initial_params)
    search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)  # trade off b/w optimization and search space

    # Parameter space
    param_space = {
        "train_loop_config": {
            "dropout_p": tune.uniform(0.3, 0.9),
            "lr": tune.loguniform(1e-5, 5e-4),
            "lr_factor": tune.uniform(0.1, 0.9),
            "lr_patience": tune.uniform(1, 10),
        }
    }

    # Scheduler
    scheduler = AsyncHyperBandScheduler(
        max_t=train_loop_config["num_epochs"],  # max epoch (<time_attr>) per trial
        grace_period=1,  # min epoch (<time_attr>) per trial
    )

    # Tune config
    tune_config = tune.TuneConfig(
        metric="val_loss",
        mode="min",
        search_alg=search_alg,
        scheduler=scheduler,
        num_samples=num_runs,
    )

    # Tuner
    tuner = Tuner(
        trainable=trainer,
        run_config=run_config,
        param_space=param_space,
        tune_config=tune_config,
    )

    # Tune
    results = tuner.fit()
    best_trial = results.get_best_result(metric="val_loss", mode="min")
    d = {
        "timestamp": datetime.datetime.now().strftime("%B %d, %Y %I:%M:%S %p"),
        "run_id": utils.get_run_id(experiment_name=experiment_name, trial_id=best_trial.metrics["trial_id"]),
        "params": best_trial.config["train_loop_config"],
        "metrics": utils.dict_to_list(best_trial.metrics_dataframe.to_dict(), keys=["epoch", "train_loss", "val_loss"]),
    }
    logger.info(json.dumps(d, indent=2))
    if results_fp:  # pragma: no cover, saving results
        utils.save_dict(d, results_fp)
    return results


if __name__ == "__main__":  # pragma: no cover, application
    if ray.is_initialized():
        ray.shutdown()
    ray.init()
    app()
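
For reference, here is a minimal sketch (not part of the repository) of how tune_models could be driven directly from Python instead of through the Typer CLI. The experiment name, dataset location, results path, and initial hyperparameter values below are hypothetical placeholders; the initial_params string mirrors the structure that HyperOptSearch's points_to_evaluate expects, matching the param_space defined above.

# Hypothetical driver script; paths and values are placeholders, not from the repo.
import json

from madewithml.tune import tune_models

if __name__ == "__main__":
    # One starting point per dict, shaped like param_space ("train_loop_config" nesting).
    initial_params = json.dumps(
        [{"train_loop_config": {"dropout_p": 0.5, "lr": 1e-4, "lr_factor": 0.8, "lr_patience": 3}}]
    )
    results = tune_models(
        experiment_name="tune-demo",              # hypothetical experiment name
        dataset_loc="datasets/dataset.csv",       # hypothetical dataset location
        initial_params=initial_params,
        num_workers=1,
        cpu_per_worker=2,
        gpu_per_worker=0,
        num_runs=2,                               # number of trials to sample
        num_epochs=2,
        batch_size=256,
        results_fp="results/tuning_results.json", # hypothetical output path
    )
    # Inspect the best trial's hyperparameters.
    print(results.get_best_result(metric="val_loss", mode="min").config)

Because @app.command() registers the command but returns the original function, calling tune_models this way bypasses Typer entirely while exercising the same tuning workload.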