Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 1.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
  1. import os
  2. import pickle
  3. import sys
  4. import numpy as np
  5. import yaml
  6. from sklearn.ensemble import RandomForestClassifier
  7. def train(seed, n_est, min_split, matrix):
  8. """
  9. Train a random forest classifier.
  10. Args:
  11. seed (int): Random seed.
  12. n_est (int): Number of trees in the forest.
  13. min_split (int): Minimum number of samples required to split an internal node.
  14. matrix (scipy.sparse.csr_matrix): Input matrix.
  15. Returns:
  16. sklearn.ensemble.RandomForestClassifier: Trained classifier.
  17. """
  18. labels = np.squeeze(matrix[:, 1].toarray())
  19. x = matrix[:, 2:]
  20. sys.stderr.write("Input matrix size {}\n".format(matrix.shape))
  21. sys.stderr.write("X matrix size {}\n".format(x.shape))
  22. sys.stderr.write("Y matrix size {}\n".format(labels.shape))
  23. clf = RandomForestClassifier(
  24. n_estimators=n_est, min_samples_split=min_split, n_jobs=2, random_state=seed
  25. )
  26. clf.fit(x, labels)
  27. return clf
  28. def main():
  29. params = yaml.safe_load(open("params.yaml"))["train"]
  30. if len(sys.argv) != 3:
  31. sys.stderr.write("Arguments error. Usage:\n")
  32. sys.stderr.write("\tpython train.py features model\n")
  33. sys.exit(1)
  34. input = sys.argv[1]
  35. output = sys.argv[2]
  36. seed = params["seed"]
  37. n_est = params["n_est"]
  38. min_split = params["min_split"]
  39. # Load the data
  40. with open(os.path.join(input, "train.pkl"), "rb") as fd:
  41. matrix, _ = pickle.load(fd)
  42. clf = train(seed=seed, n_est=n_est, min_split=min_split, matrix=matrix)
  43. # Save the model
  44. with open(output, "wb") as fd:
  45. pickle.dump(clf, fd)
  46. if __name__ == "__main__":
  47. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...