Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

shared.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
  1. import os
  2. data_dir = os.path.join(os.path.dirname(__file__), '../data/')
  3. outputs_dir = os.path.join(os.path.dirname(__file__), '../outputs/')
  4. raw_data = os.path.join(data_dir, 'CrossValidated-Questions.csv')
  5. raw_ds_data = os.path.join(data_dir, 'DataScience-Posts.csv')
  6. train_data = os.path.join(data_dir, 'train-raw.csv')
  7. test_data = os.path.join(data_dir, 'test-raw.csv')
  8. train_processed = os.path.join(data_dir, 'train-processed.pkl')
  9. test_processed = os.path.join(data_dir, 'test-processed.pkl')
  10. classifier_pkl = os.path.join(outputs_dir, 'classifier.pkl')
  11. pipeline_pkl = os.path.join(outputs_dir, 'pipeline.pkl')
  12. col_id = 'Id'
  13. col_text = 'Text'
  14. col_title = 'Title'
  15. col_body = 'Body'
  16. col_tags = 'Tags'
  17. col_label = 'IsTaggedML'
  18. extra_feature_cols = ['Score','ViewCount','AnswerCount','CommentCount','FavoriteCount']
  19. text_cols = [col_title, col_body]
  20. text_len_col = 'Text_Len'
  21. all_raw_cols = ['Id','Title','Body','Tags','CreationDate','Score','ViewCount','AnswerCount','CommentCount','FavoriteCount','IsTaggedML']
  22. def save(obj, path):
  23. import pickle
  24. with open(path, 'wb') as f:
  25. pickle.dump(obj, f)
  26. def load(path):
  27. import pickle
  28. with open(path, 'rb') as f:
  29. return pickle.load(f)
  30. def load_labels(path=train_data):
  31. import pandas as pd
  32. return pd.read_csv(path, usecols=[col_label])[col_label]
  33. def compute_metrics(clf, X, y, prefix):
  34. from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, roc_auc_score, balanced_accuracy_score, auc, precision_recall_curve
  35. X = X.sparse.to_coo()
  36. preds = clf.predict(X)
  37. probas = clf.predict_proba(X)[:,1]
  38. pr_curve = precision_recall_curve(y, probas)
  39. return {
  40. f"{prefix}_accuracy_score": accuracy_score(y, preds),
  41. f"{prefix}_f1_score": f1_score(y, preds),
  42. f"{prefix}_recall_score": recall_score(y, preds),
  43. f"{prefix}_precision_score": precision_score(y, preds),
  44. f"{prefix}_roc_auc_score": roc_auc_score(y, probas),
  45. f"{prefix}_pr_auc_score": auc(pr_curve[1], pr_curve[0]),
  46. f"{prefix}_balanced_accuracy_score": balanced_accuracy_score(y, preds)
  47. }
  48. def load_data_and_labels(processed_path):
  49. X = load(processed_path)
  50. y = X[col_label]
  51. X = X.drop(columns=[col_id,col_label])
  52. return X, y
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...