train.py
import os
import sys
import pickle

import yaml
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.multiclass import OneVsRestClassifier

from transform import preprocess
from parameters import parameters

# Load the training settings from the `train` section of params.yaml.
with open('params.yaml') as f:
    params = yaml.safe_load(f)['train']
seed = params['seed']
split = params['split']
model_type = params['model']

# Preprocessing step provided by transform.py.
preprocessor = preprocess()

# Pick the estimator named in params.yaml.
if model_type == 'random forest':
    ml_model = RandomForestClassifier()
elif model_type == 'support vector machine':
    ml_model = SVC()
elif model_type == 'logistic regression':
    ml_model = LogisticRegression()
elif model_type == 'kneighbors':
    ml_model = KNeighborsClassifier()
else:
    sys.exit(f'Unknown model type: {model_type}')

input_dir = sys.argv[1]                            # e.g. data/prepared
output_path = os.path.join('model', sys.argv[2])   # e.g. model/model.pkl


def split_dataset(df):
    """Split the prepared dataframe into train/test sets.

    'Star type' is the target; every other column is a feature.
    """
    X = df[df.columns[:-1]]
    y = df['Star type']
    return train_test_split(X, y, test_size=split, random_state=seed)


def fit(ml_model, model_type, X_train, y_train):
    """Grid-search the chosen estimator inside the preprocessing pipeline."""
    clf = Pipeline(
        steps=[
            ('preparation', preprocessor),
            ('classifier', OneVsRestClassifier(ml_model)),
        ]
    )
    # print(clf.get_params().keys())
    param_grid = parameters(model_type)
    grid_search = GridSearchCV(clf, param_grid, cv=10)
    return grid_search.fit(X_train, y_train)


os.makedirs('model', exist_ok=True)

# Load the prepared dataframe, train, and persist the fitted search object.
with open(os.path.join(input_dir, 'data.pkl'), 'rb') as fd:
    df = pickle.load(fd)

X_train, X_test, y_train, y_test = split_dataset(df)
model = fit(ml_model, model_type, X_train, y_train)

with open(output_path, 'wb') as fd:
    pickle.dump(model, fd)

# Usage: python src/train.py data/prepared model.pkl
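
The script leans on two companion pieces that are not shown on this page: the train section of params.yaml (keys seed, split, and model) and parameters.py, whose parameters() function returns the grid handed to GridSearchCV. Below is a minimal sketch of what parameters() could look like; the hyperparameter values are made up, but the classifier__estimator__ prefix is dictated by the pipeline above, since each estimator is nested inside OneVsRestClassifier under the step named 'classifier'.

# parameters.py (hypothetical sketch; the real grids live in the repo's parameters.py)
def parameters(model_type):
    # Keys are prefixed with the Pipeline step name ('classifier') and the
    # OneVsRestClassifier attribute ('estimator') so GridSearchCV can route
    # them to the wrapped model, e.g. classifier__estimator__C.
    grids = {
        'logistic regression': {
            'classifier__estimator__C': [0.1, 1.0, 10.0],
        },
        'support vector machine': {
            'classifier__estimator__C': [0.1, 1.0, 10.0],
            'classifier__estimator__kernel': ['linear', 'rbf'],
        },
        'random forest': {
            'classifier__estimator__n_estimators': [50, 100, 200],
            'classifier__estimator__max_depth': [None, 5, 10],
        },
        'kneighbors': {
            'classifier__estimator__n_neighbors': [3, 5, 7],
        },
    }
    return grids[model_type]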
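
Once the stage has run, model/model.pkl holds the fitted GridSearchCV object, i.e. the whole preprocessing-plus-classifier pipeline refit on the best parameters. A minimal, hypothetical way to consume it is sketched below; the split values 0.2 and 42 stand in for whatever params.yaml actually holds and must match training for the held-out set to line up.

# Hypothetical consumer of the trained model (not a file in the repo shown).
import pickle

from sklearn.model_selection import train_test_split

with open('model/model.pkl', 'rb') as fd:
    model = pickle.load(fd)        # the fitted GridSearchCV from train.py

with open('data/prepared/data.pkl', 'rb') as fd:
    df = pickle.load(fd)

# Recreate the held-out split; test_size and random_state are assumptions
# standing in for the split/seed values in params.yaml.
X = df[df.columns[:-1]]
y = df['Star type']
_, X_test, _, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print('best params:', model.best_params_)
print('test accuracy:', model.score(X_test, y_test))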