train.py 1.4 KB

import argparse
import pandas as pd
from sklearn import tree
import pickle

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", type=str, required=True, help="Path to train dataset")
    parser.add_argument("--test_dataset", type=str, required=True, help="Path to test dataset")
    parser.add_argument("--target_column", type=str, required=True, help="Column with classname")
    parser.add_argument("--model_path", type=str, required=True, help="Path where the model will be saved")
    parser.add_argument("--predictions_path", type=str, required=True, help="Path where model predictions will be saved")
    parser.add_argument("--max_depth", type=int, default=1, help="Max tree depth")
    args = parser.parse_args()

    # Load train dataset and split it into features and target
    df_train = pd.read_csv(args.train_dataset)
    y_train = df_train[args.target_column]
    xs_train = df_train.drop(args.target_column, axis=1)

    # Load test dataset (features only; the target column is dropped)
    df_test = pd.read_csv(args.test_dataset)
    xs_test = df_test.drop(args.target_column, axis=1)

    # Train
    clf = tree.DecisionTreeClassifier(max_depth=args.max_depth)
    clf.fit(xs_train, y_train)

    # Save model
    with open(args.model_path, 'wb') as f:
        pickle.dump(clf, f)

    # Save predictions as a single-column CSV named after the target column
    predictions = pd.DataFrame()
    predictions[args.target_column] = clf.predict(xs_test)
    predictions.to_csv(args.predictions_path, index=False)
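
The script writes a pickled model and a predictions CSV but does not score them. A minimal sketch of how those two outputs might be consumed afterwards is given here. The helper names `load_model` and `accuracy_from_files` are hypothetical and not part of train.py. The sketch assumes the test CSV still contains the target column and that the predictions file has the same row order as the test file, which is how the script writes it.

```python
import pickle

import pandas as pd


def load_model(model_path):
    """Load the object that train.py pickled to --model_path (hypothetical helper)."""
    # Only unpickle files you trust: pickle can execute arbitrary code on load.
    with open(model_path, "rb") as f:
        return pickle.load(f)


def accuracy_from_files(test_dataset, predictions_path, target_column):
    """Fraction of rows where the saved prediction equals the true label."""
    y_true = pd.read_csv(test_dataset)[target_column]
    y_pred = pd.read_csv(predictions_path)[target_column]
    if len(y_true) != len(y_pred):
        raise ValueError("test set and predictions have different row counts")
    # Compare by position, assuming both files share the same row order.
    return float((y_true.to_numpy() == y_pred.to_numpy()).mean())
```

With placeholder file names, a call such as `accuracy_from_files("test.csv", "predictions.csv", "label")` would return a value between 0 and 1. Accuracy is only one possible metric and can mislead on imbalanced classes.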