tune.py

import sys
from dataclasses import replace
from pathlib import Path
from typing import List, Tuple, Union

import click


def get_local_path():
    debug_local = True  # use the local checkout of yspecies if present
    local = (Path(".") / "yspecies").resolve()
    if debug_local and local.exists():
        # sys.path.insert(0, Path(".").as_posix())
        sys.path.insert(0, local.as_posix())
        print("extending paths with local yspecies")
        print(sys.path)
    return local


@click.group()
@click.option('--debug/--no-debug', default=False)
def cli(debug):
    click.echo('Debug mode is %s' % ('on' if debug else 'off'))


def tune_imp(trait: str, metrics: str, trials: int, folds: int, hold_outs: int, repeats: int,
             not_validated_species: Union[bool, List[str]], threads: int, debug_local: bool):
    from loguru import logger
    local = get_local_path()

    from yspecies.config import Locations
    locations: Locations = Locations("./") if Path("./data").exists() else Locations("../")
    logger.add(locations.logs / "tune_errors.log", backtrace=True, diagnose=True)
    logger.add(locations.logs / "tune.log", rotation="12:00")  # a new log file is created each day at noon
    logger.info(f"starting hyperparameter optimization script with {trials} trials, {folds} folds and {hold_outs} hold-outs!")

    importance_type = "split"
    life_history = ["lifespan", "mass_kg", "mtGC", "metabolic_rate", "temperature", "gestation_days"]

    from yspecies.config import DataLoader
    from yspecies.preprocess import FeatureSelection
    import pprint
    pp = pprint.PrettyPrinter(indent=4)

    # ### Loading data ###
    # Load the species/genes/expressions selected by the select_samples.py notebook
    default_selection = FeatureSelection(
        samples=["tissue", "species"],      # samples metadata to include
        species=[],                         # species metadata other than the Y label to include
        exclude_from_training=["species"],  # exclude some fields from LightGBM training
        to_predict=trait,                   # column to predict
        categorical=["tissue"],
        select_by="shap",
        importance_type=importance_type,
        feature_perturbation="tree_path_dependent"
    )
    loader = DataLoader(locations, default_selection)
    selections = loader.load_life_history()
    to_select = selections[trait]
    optimize(folds, hold_outs, locations, metrics, repeats, to_select, trait, trials)


def optimize(folds, hold_outs, locations, metrics, repeats, to_select, trait, trials):
    from sklearn.pipeline import Pipeline
    from yspecies.workflow import Repeat, Collect
    from yspecies.preprocess import FeatureSelection, DataExtractor
    from yspecies.partition import DataPartitioner, PartitionParameters
    from yspecies.selection import ShapSelector
    from yspecies.tuning import Tune
    from yspecies.explanations import FeatureSummary, FeatureResults
    import optuna
    from optuna import Trial

    # ## Setting up ShapSelector ##
    # Deciding on selection parameters (which fields to include, exclude, predict)
    partition_params = PartitionParameters(folds, hold_outs, 2, 42)
    selection = FeatureSelection(
        samples=["tissue", "species"],      # samples metadata to include
        species=[],                         # species metadata other than the Y label to include
        exclude_from_training=["species"],  # exclude some fields from LightGBM training
        to_predict=trait,                   # column to predict
        categorical=["tissue"],
        select_by="shap",
        importance_type="split"
    )

    url = 'sqlite:///' + str((locations.interim.optimization / f"{trait}.sqlite").absolute())
    print('loading (if it exists) study from ' + url)
    storage = optuna.storages.RDBStorage(
        url=url
        # engine_kwargs={'check_same_thread': False}
    )
    # directions match get_objectives below: maximize R2, minimize huber loss, maximize mean |Kendall tau|
    study = optuna.multi_objective.study.create_study(directions=['maximize', 'minimize', 'maximize'],
                                                      storage=storage,
                                                      study_name=f"{trait}_{metrics}",
                                                      load_if_exists=True)
    study.get_pareto_front_trials()

    def objective_parameters(trial: Trial) -> dict:
        return {
            'objective': 'regression',
            'metric': {'mae', 'mse', 'huber'},
            'verbosity': -1,
            'boosting_type': trial.suggest_categorical('boosting_type', ['dart', 'gbdt']),
            'lambda_l1': trial.suggest_uniform('lambda_l1', 0.01, 3.0),
            'lambda_l2': trial.suggest_uniform('lambda_l2', 0.01, 3.0),
            'max_leaves': trial.suggest_int("max_leaves", 15, 25),
            'max_depth': trial.suggest_int('max_depth', 3, 8),
            'feature_fraction': trial.suggest_uniform('feature_fraction', 0.3, 1.0),
            'bagging_fraction': trial.suggest_uniform('bagging_fraction', 0.3, 1.0),
            'learning_rate': trial.suggest_uniform('learning_rate', 0.01, 0.1),
            'min_data_in_leaf': trial.suggest_int('min_data_in_leaf', 3, 8),
            'drop_rate': trial.suggest_uniform('drop_rate', 0.1, 0.3),
            "verbose": -1
        }

    optimization_parameters = objective_parameters

    from yspecies.workflow import SplitReduce

    def side(i: int):
        print(i)
        return i

    prepare_partition = SplitReduce(
        outputs=DataPartitioner(),
        split=lambda x: [(x[0], replace(partition_params, seed=side(x[2])))],
        reduce=lambda x, output: (output[0], x[1])
    )
    partition_and_cv = Pipeline(
        [
            ("prepare partition", prepare_partition),
            ("shap_computation", ShapSelector())  # ('crossvalidator', CrossValidator())
        ]
    )

    def get_objectives(results: List[FeatureResults]) -> Tuple[float, float, float]:
        summary = FeatureSummary(results)
        return (summary.metrics_average.R2, summary.metrics_average.huber, summary.kendall_tau_abs_mean)

    partition_and_cv_repeat = Pipeline([
        ("repeat_cv_pipe", Repeat(partition_and_cv, repeats, lambda x, i: [x[0], x[1], i])),
        ("collect_mean", Collect(fold=lambda outputs: get_objectives(outputs)))
    ])

    p = Pipeline([
        ('extractor', DataExtractor()),
        ('tune', Tune(partition_and_cv_repeat, study=study, n_trials=trials, parameters_space=optimization_parameters))
    ])

    from yspecies.tuning import MultiObjectiveResults
    results: MultiObjectiveResults = p.fit_transform(to_select)
    best = results.best_trials

    import json
    for i, t in enumerate(best):
        trait_path = locations.metrics.optimization / trait
        if not trait_path.exists():
            trait_path.mkdir()
        path = trait_path / f"{str(i)}.json"
        print(f"writing parameters to {path}")
        with open(path, 'w') as f:
            params = t.params
            values = t.values
            to_write = {"number": t.number, "params": params,
                        "metrics": {"R2": values[0], "huber": values[1], "kendall_tau": values[2]}}
            json.dump(to_write, f, sort_keys=True, indent=4)
    print(f"FINISHED HYPER-OPTIMIZING {trait}")


# @click.group(invoke_without_command=True)
@cli.command()
@click.option('--trait', default="lifespan", help='trait name')
@click.option('--metrics', default="r2_huber_kendall", help='metrics names')
@click.option('--trials', default=200, help='number of trials in hyperparameter optimization')
@click.option('--folds', default=5, help='number of folds in cross-validation')
@click.option('--hold_outs', default=1, help='number of hold-outs in cross-validation')
@click.option('--repeats', default=5, help='number of times to repeat validation')
@click.option('--not_validated_species', default=True, help='not_validated_species')
@click.option('--threads', default=1, help='number of threads (1 by default); -1 tries to use all cores, which can be dangerous memory-wise')
@click.option('--debug_local', default=True, help='debug local')
def tune(trait: str, metrics: str, trials: int, folds: int, hold_outs: int, repeats: int,
         not_validated_species: Union[bool, List[str]], threads: int, debug_local: bool):
    return tune_imp(trait, metrics, trials, folds, hold_outs, repeats, not_validated_species, threads, debug_local)


@cli.command()
@click.option('--life_history', multiple=True,
              default=["lifespan", "mass_kg", "gestation_days", "mtGC", "metabolic_rate", "temperature"],
              help='life-history trait list')
@click.option('--metrics', default="r2_huber_kendall", help='metrics names')
@click.option('--trials', default=10, help='number of trials in hyperparameter optimization')
@click.option('--folds', default=5, help='number of folds in cross-validation')
@click.option('--hold_outs', default=1, help='number of hold-outs in cross-validation')
@click.option('--repeats', default=5, help='number of times to repeat validation')
@click.option('--not_validated_species', default=True, help='not_validated_species')
@click.option('--threads', default=1, help='number of threads (1 by default); -1 tries to use all cores, which can be dangerous memory-wise')
@click.option('--debug_local', default=True, help='debug local')
def tune_all(life_history: List[str],
             metrics: str,
             trials: int,
             folds: int,
             hold_outs: int,
             repeats: int,
             not_validated_species: Union[bool, List[str]],
             threads: int,
             debug_local: bool):
    for trait in life_history:
        print(f"tuning {trait} with {trials} trials")
        tune_imp(trait, metrics, trials, folds, hold_outs, repeats, not_validated_species, threads, debug_local)


if __name__ == "__main__":
    cli()
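
For reference, a minimal sketch of how these click commands might be invoked from the repository root. The option values below are illustrative, echoing the decorator defaults rather than recommended settings, and depending on the installed click version the second command may be exposed as tune-all instead of tune_all:

    python tune.py tune --trait lifespan --trials 200 --folds 5 --repeats 5
    python tune.py tune_all --trials 10 --metrics r2_huber_kendall

Because the Optuna study is backed by a per-trait SQLite file and created with load_if_exists=True, re-running the same command resumes the existing study instead of starting over.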