train.py

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, precision_score, recall_score, \
    f1_score
from sklearn.model_selection import train_test_split
import joblib
import os
import dagshub

# Consts
CLASS_LABEL = 'MachineLearning'
train_df_path = 'data/train.csv'
test_df_path = 'data/test.csv'


def fit_tfidf(train_df, test_df):
    # Fit the TF-IDF vectorizer on the training text only, then transform both splits.
    tfidf = TfidfVectorizer(max_features=25000, ngram_range=(1, 2))
    tfidf.fit(train_df['Text'])
    train_tfidf = tfidf.transform(train_df['Text'])
    test_tfidf = tfidf.transform(test_df['Text'])
    return train_tfidf, test_tfidf, tfidf


def fit_model(train_X, train_y, random_state=42):
    clf_tfidf = RandomForestClassifier(random_state=random_state, max_depth=50, class_weight='balanced')
    clf_tfidf.fit(train_X, train_y)
    return clf_tfidf


def eval_model(clf, X, y):
    y_proba = clf.predict_proba(X)[:, 1]
    y_pred = clf.predict(X)
    return {
        'roc_auc': roc_auc_score(y, y_proba),
        'average_precision': average_precision_score(y, y_proba),
        'accuracy': accuracy_score(y, y_pred),
        'precision': precision_score(y, y_pred),
        'recall': recall_score(y, y_pred),
        'f1': f1_score(y, y_pred),
    }


# Prepare a dictionary of either hyperparams or metrics for logging.
def prepare_log(d, prefix=''):
    if prefix:
        prefix = f'{prefix}__'

    # Ensure all logged values are suitable for logging - complex objects aren't supported.
    def sanitize(value):
        return value if value is None or type(value) in [str, int, float, bool] else str(value)

    return {f'{prefix}{k}': sanitize(v) for k, v in d.items()}


def train():
    print('Loading data...')
    train_df = pd.read_csv(train_df_path)
    test_df = pd.read_csv(test_df_path)

    # Create outputs directory if it doesn't exist
    os.makedirs("outputs", exist_ok=True)

    with dagshub.dagshub_logger() as logger:
        print('Fitting TFIDF...')
        train_tfidf, test_tfidf, tfidf = fit_tfidf(train_df, test_df)

        print('Saving TFIDF object...')
        joblib.dump(tfidf, 'outputs/tfidf.joblib')
        logger.log_hyperparams(prepare_log(tfidf.get_params(), 'tfidf'))

        print('Training model...')
        train_y = train_df[CLASS_LABEL]
        model = fit_model(train_tfidf, train_y)

        print('Saving trained model...')
        joblib.dump(model, 'outputs/model.joblib')
        logger.log_hyperparams(model_class=type(model).__name__)
        logger.log_hyperparams(prepare_log(model.get_params(), 'model'))

        print('Evaluating model...')
        train_metrics = eval_model(model, train_tfidf, train_y)
        print('Train metrics:')
        print(train_metrics)
        logger.log_metrics(prepare_log(train_metrics, 'train'))

        test_metrics = eval_model(model, test_tfidf, test_df[CLASS_LABEL])
        print('Test metrics:')
        print(test_metrics)
        logger.log_metrics(prepare_log(test_metrics, 'test'))


if __name__ == '__main__':
    train()
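
For reference, here is a minimal sketch of how the artifacts saved by train.py could be used for inference, assuming the script has already been run so that outputs/tfidf.joblib and outputs/model.joblib exist; the predict.py filename and the example text are hypothetical.

# predict.py (hypothetical usage sketch)
# Load the TF-IDF vectorizer and model saved by train.py and score new text.
import joblib

tfidf = joblib.load('outputs/tfidf.joblib')
model = joblib.load('outputs/model.joblib')

texts = ['A new paper on transformer architectures for NLP']  # example input
features = tfidf.transform(texts)
# predict_proba returns [P(negative), P(positive)] per row; column 1 is the
# probability that the text belongs to the 'MachineLearning' class.
print(model.predict_proba(features)[:, 1])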