train.py 857 B

import sys
import os

import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Python 2 ships the faster cPickle; Python 3 only has pickle.
try:
    import cPickle as pickle
except ImportError:
    import pickle

if len(sys.argv) != 3:
    sys.stderr.write('Arguments error. Usage:\n')
    sys.stderr.write('\tpython train.py features model\n')
    sys.exit(1)

input = sys.argv[1]   # directory that contains train.pkl
output = sys.argv[2]  # path the trained model is written to
seed = 20170426

with open(os.path.join(input, 'train.pkl'), 'rb') as fd:
    matrix = pickle.load(fd)

# The matrix must expose .toarray() (e.g. a SciPy sparse matrix).
# Column 1 holds the labels; columns 2 onward hold the features.
labels = np.squeeze(matrix[:, 1].toarray())
x = matrix[:, 2:]

sys.stderr.write('Input matrix size {}\n'.format(matrix.shape))
sys.stderr.write('X matrix size {}\n'.format(x.shape))
sys.stderr.write('Y matrix size {}\n'.format(labels.shape))

# A fixed random_state makes the trained forest reproducible.
clf = RandomForestClassifier(n_estimators=100, n_jobs=2, random_state=seed)
clf.fit(x, labels)

with open(output, 'wb') as fd:
    pickle.dump(clf, fd)