Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
  1. import pandas as pd
  2. from xgboost import XGBClassifier
  3. from sklearn.ensemble import RandomForestClassifier
  4. from sklearn.neighbors import KNeighborsClassifier
  5. from sklearn.tree import DecisionTreeClassifier
  6. from sklearn.metrics import f1_score, make_scorer
  7. from typing import Dict, Text
  8. from sklearn.model_selection import GridSearchCV
  class UnsupportedClassifier(Exception):
      """Error for an estimator name that has no supported classifier.

      Attributes:
          estimator_name: the estimator name that was rejected
          msg {Text}: formatted message, also passed to ``Exception``
      """

      def __init__(self, estimator_name):
          # Keep the raw name so handlers do not have to parse the message.
          self.estimator_name = estimator_name
          self.msg = f'Unsupported estimator {estimator_name}'
          super().__init__(self.msg)
  def get_supported_estimator() -> Dict:
      """Return the mapping of estimator names to classifier classes.

      The values are the classes themselves, not instances, so the caller
      decides when and how to instantiate the one it needs.

      Returns:
          Dict: supported classifiers, keyed by estimator name
      """

      return {
          'random_forest': RandomForestClassifier,
          'decision_tree': DecisionTreeClassifier,
          'knn': KNeighborsClassifier,
          'xgb': XGBClassifier,
      }
  def model(df: pd.DataFrame, target_column: Text,
            estimator_name: Text, param_grid: Dict, cv: int):
      """Train model.

      Runs a grid search over ``param_grid`` for the requested estimator,
      scoring each candidate with the weighted F1 score, and fits it on the
      whole of ``df``.

      Args:
          df {pandas.DataFrame}: dataset
          target_column {Text}: target column name
          estimator_name {Text}: estimator name
          param_grid {Dict}: grid parameters
          cv {int}: cross-validation value
      Returns:
          trained model (the fitted ``GridSearchCV`` object)
      Raises:
          UnsupportedClassifier: if ``estimator_name`` is not one of the
              keys returned by ``get_supported_estimator()``
      """

      estimators = get_supported_estimator()

      # Membership on the dict itself already tests its keys.
      if estimator_name not in estimators:
          raise UnsupportedClassifier(estimator_name)

      # The mapping holds classes; instantiate with default parameters and
      # let the grid search set the tuned ones.
      estimator = estimators[estimator_name]()
      f1_scorer = make_scorer(f1_score, average='weighted')

      # NOTE(review): dict(param_grid) presumably normalises a non-dict
      # mapping (e.g. from a config loader) -- confirm against callers.
      clf = GridSearchCV(estimator=estimator,
                         param_grid=dict(param_grid),
                         cv=cv,
                         verbose=1,
                         scoring=f1_scorer)

      # Get X and Y: the target column is cast to int32 labels, every other
      # column is used as a float32 feature.
      y_train = df.loc[:, target_column].values.astype('int32')
      X_train = df.drop(target_column, axis=1).values.astype('float32')

      clf.fit(X_train, y_train)

      return clf
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...