run_grid_search.py

import argparse
import json
import warnings
from itertools import product

import mlflow

from model.model_training import train_random_forest_model
from model.utils import model_metrics


def grid_search_random_forest(name_experiment):
    """Run a grid search over the hyperparameters specified below."""
    max_depth = [3, 6]
    criterion = ['gini', 'entropy']
    min_samples_leaf = [5, 10]
    n_estimators = [50, 100]
    parameters_list = list(product(max_depth, criterion, min_samples_leaf, n_estimators))
    print('Number of experiments:', len(parameters_list))

    # Hyperparameter search
    results = []
    best_param = None
    best_f1 = 0.0
    warnings.filterwarnings('ignore')
    for i, param in enumerate(parameters_list):
        print('Running experiment number', i)
        with mlflow.start_run(run_name=name_experiment):
            # Log the parameters of this run to the MLflow experiments dashboard
            mlflow.log_param('depth', param[0])
            mlflow.log_param('criterion', param[1])
            mlflow.log_param('minsamplesleaf', param[2])
            mlflow.log_param('nestimators', param[3])
            try:
                parameters = dict(n_estimators=param[3],
                                  max_depth=param[0],
                                  criterion=param[1],
                                  min_samples_leaf=param[2])
                clf = train_random_forest_model(data_path='./data/adult_training.csv',
                                                parameters=parameters)
                metrics = model_metrics(clf, data_path='./data/adult_validation.csv')

                # Log the validation metrics for the '>50K' class
                mlflow.log_metric("precision", metrics['>50K']['precision'])
                mlflow.log_metric("F1", metrics['>50K']['f1-score'])

                # Store the full metrics report as an artifact for each run
                with open("metrics.json", "w") as f:
                    json.dump(metrics, f)
                mlflow.log_artifact('./metrics.json')

                # Keep track of the best experiment so far (in terms of F1 score)
                if metrics['>50K']['f1-score'] > best_f1:
                    best_param = parameters
                    best_f1 = metrics['>50K']['f1-score']
                results.append([param, metrics['>50K']['f1-score']])
            except ValueError:
                print('bad parameter combination:', param)
                continue

    print('Best F1 was:', best_f1)
    print('Using the following parameters')
    print(best_param)
    return results, best_param


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", help="experiment name")
    args, leftovers = parser.parse_known_args()
    results, best_param = grid_search_random_forest(args.name)
    with open("best_params.json", "w") as f:
        json.dump(best_param, f)
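Usage note: as a rough sketch of how the script is meant to be invoked (assuming the repository's `data/` CSV files and `model` package are in place, and using a placeholder experiment name), something like `python run_grid_search.py --name rf-grid-search` should launch all 16 parameter combinations (2 x 2 x 2 x 2), log each run to MLflow, and write the best parameter set to `best_params.json` in the working directory.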