train_model.py 887 B

import sys

import numpy as np
from sklearn.ensemble import RandomForestClassifier

import conf

try:
    import cPickle as pickle  # Python 2
except ImportError:
    import pickle  # Python 3

# The only command-line argument is the random seed; file paths come from conf.
if len(sys.argv) != 2:
    sys.stderr.write('Arguments error. Usage:\n')
    sys.stderr.write('\tpython train_model.py SEED\n')
    sys.exit(1)

input_path = conf.train_matrix
output_path = conf.model
seed = int(sys.argv[1])

# Load the pickled sparse training matrix.
with open(input_path, 'rb') as fd:
    matrix = pickle.load(fd)

# Column 1 holds the labels; columns 2 onward hold the features.
labels = np.squeeze(matrix[:, 1].toarray())
x = matrix[:, 2:]

sys.stderr.write('Input matrix size {}\n'.format(matrix.shape))
sys.stderr.write('X matrix size {}\n'.format(x.shape))
sys.stderr.write('Y matrix size {}\n'.format(labels.shape))

# Train the random forest with a fixed seed so runs are reproducible.
clf = RandomForestClassifier(n_estimators=700, n_jobs=6, random_state=seed)
clf.fit(x, labels)

# Persist the trained model.
with open(output_path, 'wb') as fd:
    pickle.dump(clf, fd)
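
The script above expects a sparse matrix whose column 1 holds the labels and whose columns 2 onward hold the features. The following is a minimal sketch of that layout on synthetic data, not part of the original file. It assumes column 0 is a row identifier (the script skips it without saying what it contains), and `split_matrix` is a hypothetical helper that mirrors the script's slicing. The forest is kept small so the sketch runs quickly.

```python
import pickle

import numpy as np
from scipy.sparse import csr_matrix
from sklearn.ensemble import RandomForestClassifier


def split_matrix(matrix):
    """Hypothetical helper: split a sparse matrix the way train_model.py does."""
    labels = np.squeeze(matrix[:, 1].toarray())  # column 1: labels
    x = matrix[:, 2:]                            # columns 2+: features
    return x, labels


# Synthetic stand-in for the pickled training matrix (assumed layout).
rng = np.random.RandomState(0)
ids = np.arange(20).reshape(-1, 1)           # column 0: assumed to be an ID
y = (np.arange(20) % 2).reshape(-1, 1)       # column 1: binary labels
features = rng.rand(20, 3)                   # columns 2-4: features
matrix = csr_matrix(np.hstack([ids, y, features]))

x, labels = split_matrix(matrix)

# Small forest for illustration; the script itself uses 700 estimators.
clf = RandomForestClassifier(n_estimators=10, random_state=0)
clf.fit(x, labels)

# Round-trip through pickle, as the script does when saving the model.
restored = pickle.loads(pickle.dumps(clf))
predictions = restored.predict(x)
```

Because `random_state` is fixed, retraining with the same seed and data should yield the same forest, which is why the script takes the seed as its only argument.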