Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  1. """
  2. Train the model.
  3. Routine Listings
  4. ----------------
  5. get_params()
  6. Get the DVC stage parameters.
  7. train(input, output, model, model_params)
  8. Train model on feature matrix.
  9. """
  10. import sys
  11. import dask
  12. import dask.distributed
  13. import numpy as np
  14. from sklearn.ensemble import RandomForestClassifier
  15. import pickle
  16. import conf
  17. def get_params():
  18. """Get the DVC stage parameters."""
  19. return {
  20. 'classifier': RandomForestClassifier,
  21. 'n_estimators': 100,
  22. 'n_jobs': 2,
  23. 'random_state': 42}
  24. @dask.delayed
  25. def train(input, output, model, model_params):
  26. """Train model on feature matrix."""
  27. with open(input, 'rb') as fd:
  28. matrix = pickle.load(fd)
  29. labels = np.squeeze(matrix[:, 1].toarray())
  30. x = matrix[:, 2:]
  31. sys.stderr.write('Input matrix size {}\n'.format(matrix.shape))
  32. sys.stderr.write('X matrix size {}\n'.format(x.shape))
  33. sys.stderr.write('Y matrix size {}\n'.format(labels.shape))
  34. clf = model(**model_params)
  35. clf.fit(x, labels)
  36. with open(output, 'wb') as fd:
  37. pickle.dump(clf, fd)
  38. if __name__ == '__main__':
  39. client = dask.distributed.Client('localhost:8786')
  40. INPUT_TRAIN_MATRIX_PATH = conf.data_dir/'featurization'/'matrix-train.p'
  41. dvc_stage_name = __file__.strip('.py')
  42. STAGE_OUTPUT_PATH = conf.data_dir/dvc_stage_name
  43. conf.remote_mkdir(STAGE_OUTPUT_PATH).compute()
  44. OUTPUT_MODEL_PATH = STAGE_OUTPUT_PATH/'model.p'
  45. config = get_params()
  46. CLASSIFIER = config['classifier']
  47. N_ESTIMATORS = config['n_estimators']
  48. N_JOBS = config['n_jobs']
  49. RANDOM_STATE = config['random_state']
  50. train(INPUT_TRAIN_MATRIX_PATH, OUTPUT_MODEL_PATH, CLASSIFIER,
  51. {'n_estimators': N_ESTIMATORS, 'n_jobs': N_JOBS,
  52. 'random_state': RANDOM_STATE}).compute()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...