model_training.py

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
import json
import os
import joblib
from data.datamanager import data_loader
import clearbox_wrapper as cbw


def train_random_forest_model(data_path: str, parameters=None):
    # Trains a random forest classifier on the data specified by data_path.
    # If parameters are not passed as an argument, look for a params.json file;
    # otherwise fall back to default values.
    if parameters is None:
        if os.path.exists('./params.json'):
            with open('./params.json', 'r') as f:
                parameters = json.load(f)
        else:
            parameters = dict(n_estimators=100, max_depth=4, criterion='gini',
                              min_sample_leaf=10)
    print(parameters)

    x_training, y_training = data_loader(data_path)

    # Scikit-learn ColumnTransformer used to preprocess ordinal (numeric) and categorical features
    ordinal_features = x_training.select_dtypes(include="number").columns
    categorical_features = x_training.select_dtypes(include="object").columns
    ordinal_transformer = Pipeline(steps=[('imputer', SimpleImputer(strategy='median')),
                                          ('scaler', StandardScaler())])
    categorical_transformer = Pipeline(steps=[('onehot', OneHotEncoder(handle_unknown='ignore'))])
    x_encoder = ColumnTransformer(transformers=[('ord', ordinal_transformer, ordinal_features),
                                                ('cat', categorical_transformer, categorical_features)])

    rf_clf = RandomForestClassifier(n_estimators=parameters['n_estimators'],
                                    max_depth=parameters['max_depth'],
                                    criterion=parameters['criterion'],
                                    min_samples_leaf=parameters['min_sample_leaf'],
                                    random_state=42)
    rf_pipeline = Pipeline(steps=[("preprocessing", x_encoder), ("rf_model", rf_clf)])
    rf_pipeline.fit(x_training, y_training)

    # Serialize the fitted pipeline with joblib and save it through clearbox_wrapper
    joblib.dump(rf_pipeline, 'model.pkl')
    cbw.save_model('./model_cbw', rf_pipeline)
    return rf_pipeline


if __name__ == '__main__':
    train_random_forest_model('./data/adult_training.csv')
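A minimal sketch of the data_loader contract this script assumes: the function imported from data.datamanager is expected to read the CSV at data_path and return the feature columns and the target separately. The version below is hypothetical (the real data/datamanager.py may differ) and assumes the target is the last column of the file.

import pandas as pd

def data_loader(data_path: str):
    # Hypothetical implementation: read the CSV and split off the last column as the target
    df = pd.read_csv(data_path)
    x = df.iloc[:, :-1]
    y = df.iloc[:, -1]
    return x, y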
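A hedged usage sketch: train with an explicit parameters dict instead of params.json, then reload the joblib-serialized pipeline and predict on new rows. The file ./data/adult_test.csv and the assumption that its last column is the target are hypothetical; any DataFrame with the same feature columns as the training data would work.

import joblib
import pandas as pd
from model_training import train_random_forest_model

custom_params = dict(n_estimators=200, max_depth=6, criterion='entropy', min_sample_leaf=5)
train_random_forest_model('./data/adult_training.csv', parameters=custom_params)

# Reload the serialized pipeline and score unseen rows
pipeline = joblib.load('model.pkl')
new_rows = pd.read_csv('./data/adult_test.csv').iloc[:, :-1]  # hypothetical hold-out file, target column dropped
predictions = pipeline.predict(new_rows)
print(predictions[:10])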
