Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

reddit_utils.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
  1. import os
  2. from sklearn.metrics import (
  3. roc_auc_score,
  4. average_precision_score,
  5. accuracy_score,
  6. precision_score,
  7. recall_score,
  8. f1_score,
  9. )
  # ----- Cloud Details -----
  # Project identifiers and the name of the credentials environment variable.
  PROJECT_NAME = "talos-project"
  BIGQUERY_PROJECT = "project-talos"
  GCLOUD_CRED_ENV_VAR = "GOOGLE_APPLICATION_CREDENTIALS"

  # ----- Constants -----
  # Column-name groups; the NUM/CAT/TEXT prefixes presumably stand for
  # numeric, categorical and free-text feature columns (confirm with callers).
  NUM_COL_NAMES = [
      "title_len",
      "body_len",
      "hour",
      "minute",
      "dayofweek",
      "dayofyear",
  ]
  CAT_COL_NAMES = [
      "has_thumbnail",
      "flair_Clickbait",
      "flair_Discussion",
      "flair_Inaccurate",
      "flair_Misleading",
      "flair_News",
      "flair_None",
      "flair_Project",
      "flair_Research",
      "flair_Shameless Self Promo",
  ]
  TEXT_COL_NAME = ["title_and_body"]

  # ----- Paths -----
  # Pickle paths are derived from MODELS_DIR so they move with it.
  MODELS_DIR = "./models"
  TFIDF_PATH = f"{MODELS_DIR}/tfidf.pkl"
  MODEL_PATH = f"{MODELS_DIR}/model.pkl"
  RAW_DF_PATH = "rML-raw-data.csv"
  TRAIN_DF_PATH = "rML-train.csv"
  TEST_DF_PATH = "rML-test.csv"
  36. # ----- Functions -----
  def calculate_metrics(y_pred, y_proba, y):
      """Compute six classification scores and return them as plain floats.

      ``roc_auc`` and ``average_precision`` are computed from ``y`` and
      ``y_proba``; ``accuracy``, ``precision``, ``recall`` and ``f1`` are
      computed from ``y`` and ``y_pred``. Each scorer is called with ``y``
      as its first argument.

      Returns:
          dict mapping metric name to a built-in ``float``.
      """
      # (metric name, scorer, second argument passed to the scorer)
      scoring_plan = (
          ("roc_auc", roc_auc_score, y_proba),
          ("average_precision", average_precision_score, y_proba),
          ("accuracy", accuracy_score, y_pred),
          ("precision", precision_score, y_pred),
          ("recall", recall_score, y_pred),
          ("f1", f1_score, y_pred),
      )
      metrics = {}
      for metric_name, scorer, estimate in scoring_plan:
          # float() turns whatever the scorer returns into a built-in float.
          metrics[metric_name] = float(scorer(y, estimate))
      return metrics
  def get_remote_gs_wfs():
      """Return the remote location reported by ``dvc remote list --local``.

      Runs the command through a shell, reads its standard output, and
      returns the text between the first tab character and the next tab or
      newline, i.e. the second tab-separated field of the first line.

      Returns:
          str: the extracted location string.

      Raises:
          IndexError: if the command output contains no tab character
              (for example when the command printed nothing).
      """
      print("Retreiving location of remote working file system...")
      # Use the pipe as a context manager so it is closed (and the child
      # process waited on) as soon as the output has been read.
      with os.popen("dvc remote list --local") as stream:
          output = stream.read()
      fields = output.split("\t")
      if len(fields) < 2:
          # Same exception type as plain indexing would raise, but with a
          # message that shows what the command actually printed.
          raise IndexError(
              "no tab-separated remote found in output of "
              f"'dvc remote list --local': {output!r}"
          )
      remote_wfs_loc = fields[1].split("\n")[0]
      return remote_wfs_loc
  # Prepare a dictionary of either hyperparams or metrics for logging.
  def prepare_log(d, prefix=""):
      """Return a copy of ``d`` with prefixed keys and log-friendly values.

      Args:
          d: mapping whose items are to be logged.
          prefix: when truthy, every key is rendered as ``"<prefix>__<key>"``;
              otherwise ``prefix`` is interpolated in front of the key as-is
              (the default ``""`` leaves the key text unchanged).

      Returns:
          dict whose keys are strings and whose values are either ``None``,
          an exact ``str``/``int``/``float``/``bool``, or the ``str()`` of
          the original value.
      """
      key_prefix = f"{prefix}__" if prefix else prefix
      loggable_types = (str, int, float, bool)

      prepared = {}
      for key, value in d.items():
          # Exact type match on purpose: complex objects (and subclasses of
          # the primitive types) are logged as their string form.
          keep_as_is = value is None or type(value) in loggable_types
          prepared[f"{key_prefix}{key}"] = value if keep_as_is else str(value)
      return prepared
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...