Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

main.py 4.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
  1. import argparse
  2. import pandas as pd
  3. from sklearn.feature_extraction.text import TfidfVectorizer
  4. from sklearn.ensemble import RandomForestClassifier
  5. from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, precision_score, recall_score, \
  6. f1_score
  7. from sklearn.model_selection import train_test_split
  8. import joblib
  9. import dagshub
# Consts
CLASS_LABEL = 'MachineLearning'  # name of the binary target column added in split()
train_df_path = 'data/train.csv.zip'  # written by split(), read by train()
test_df_path = 'data/test.csv.zip'  # written by split(), read by train()
  14. def feature_engineering(raw_df):
  15. df = raw_df.copy()
  16. df['CreationDate'] = pd.to_datetime(df['CreationDate'])
  17. df['CreationDate_Epoch'] = df['CreationDate'].astype('int64') // 10 ** 9
  18. df = df.drop(columns=['Id', 'Tags'])
  19. df['Title_Len'] = df.Title.str.len()
  20. df['Body_Len'] = df.Body.str.len()
  21. # Drop the correlated features
  22. df = df.drop(columns=['FavoriteCount'])
  23. df['Text'] = df['Title'].fillna('') + ' ' + df['Body'].fillna('')
  24. return df
  25. def fit_tfidf(train_df, test_df):
  26. tfidf = TfidfVectorizer(max_features=25000)
  27. tfidf.fit(train_df['Text'])
  28. train_tfidf = tfidf.transform(train_df['Text'])
  29. test_tfidf = tfidf.transform(test_df['Text'])
  30. return train_tfidf, test_tfidf, tfidf
  31. def fit_model(train_X, train_y, random_state=42):
  32. clf_tfidf = RandomForestClassifier(random_state=random_state)
  33. clf_tfidf.fit(train_X, train_y)
  34. return clf_tfidf
  35. def eval_model(clf, X, y):
  36. y_proba = clf.predict_proba(X)[:, 1]
  37. y_pred = clf.predict(X)
  38. return {
  39. 'roc_auc': roc_auc_score(y, y_proba),
  40. 'average_precision': average_precision_score(y, y_proba),
  41. 'accuracy': accuracy_score(y, y_pred),
  42. 'precision': precision_score(y, y_pred),
  43. 'recall': recall_score(y, y_pred),
  44. 'f1': f1_score(y, y_pred),
  45. }
  46. def split(random_state=42):
  47. print('Loading data...')
  48. df = pd.read_csv('data/CrossValidated-Questions.csv')
  49. df[CLASS_LABEL] = df['Tags'].str.contains('machine-learning').fillna(False)
  50. train_df, test_df = train_test_split(df, random_state=random_state, stratify=df[CLASS_LABEL])
  51. print('Saving split data...')
  52. train_df.to_csv(train_df_path)
  53. test_df.to_csv(test_df_path)
  54. # Prepare a dictionary of either hyperparams or metrics for logging.
  55. def prepare_log(d, prefix=''):
  56. if prefix:
  57. prefix = f'{prefix}__'
  58. # Ensure all logged values are suitable for logging - complex objects aren't supported.
  59. def sanitize(value):
  60. return value if value is None or type(value) in [str, int, float, bool] else str(value)
  61. return {f'{prefix}{k}': sanitize(v) for k, v in d.items()}
  62. def train():
  63. print('Loading data...')
  64. train_df = pd.read_csv(train_df_path)
  65. test_df = pd.read_csv(test_df_path)
  66. print('Engineering features...')
  67. train_df = feature_engineering(train_df)
  68. test_df = feature_engineering(test_df)
  69. with dagshub.dagshub_logger() as logger:
  70. print('Fitting TFIDF...')
  71. train_tfidf, test_tfidf, tfidf = fit_tfidf(train_df, test_df)
  72. print('Saving TFIDF object...')
  73. joblib.dump(tfidf, 'outputs/tfidf.joblib')
  74. logger.log_hyperparams(prepare_log(tfidf.get_params(), 'tfidf'))
  75. print('Training model...')
  76. train_y = train_df[CLASS_LABEL]
  77. model = fit_model(train_tfidf, train_y)
  78. print('Saving trained model...')
  79. joblib.dump(model, 'outputs/model.joblib')
  80. logger.log_hyperparams(model_class=type(model).__name__)
  81. logger.log_hyperparams(prepare_log(model.get_params(), 'model'))
  82. print('Evaluating model...')
  83. train_metrics = eval_model(model, train_tfidf, train_y)
  84. print('Train metrics:')
  85. print(train_metrics)
  86. logger.log_metrics(prepare_log(train_metrics, 'train'))
  87. test_metrics = eval_model(model, test_tfidf, test_df[CLASS_LABEL])
  88. print('Test metrics:')
  89. print(test_metrics)
  90. logger.log_metrics(prepare_log(test_metrics, 'test'))
  91. if __name__ == '__main__':
  92. parser = argparse.ArgumentParser()
  93. subparsers = parser.add_subparsers(title='Split or Train step:', dest='step')
  94. subparsers.required = True
  95. split_parser = subparsers.add_parser('split')
  96. split_parser.set_defaults(func=split)
  97. train_parser = subparsers.add_parser('train')
  98. train_parser.set_defaults(func=train)
  99. parser.parse_args().func()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...