Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
  1. import dagshub
  2. import os
  3. import pandas as pd
  4. import yaml
  5. import re
  6. import numpy as np
  7. import joblib
  8. from scipy.sparse.dia import dia_matrix
  9. from sklearn.feature_extraction.text import TfidfVectorizer
  10. from sklearn.linear_model import SGDClassifier
  11. from reddit_utils import calculate_metrics, prepare_log
  12. import reddit_utils
# Load shared pipeline parameters (chunk size, target column name, ...).
with open(r"./general_params.yml") as f:
    params = yaml.safe_load(f)

# Load training-stage parameters (which feature-column groups were used).
with open(r"./training_params.yml") as f:
    training_params = yaml.safe_load(f)

CHUNK_SIZE = params["chunk_size"]      # rows per chunk when streaming the test CSV
TARGET_LABEL = params["target_col"]    # name of the label column in the test data

# Identifiers for the kind of model the training stage produced.
MODEL_TYPE_TEXT = "model_text"
MODEL_TYPE_NUM_CAT = "model_num_cat"
MODEL_TYPE_OTHER = ""

# Chained conditional: the text model takes precedence when text columns were
# used; otherwise the numeric/categorical model; otherwise no known type.
MODEL_TYPE = (
    MODEL_TYPE_TEXT
    if training_params["use_text_cols"]
    else MODEL_TYPE_NUM_CAT
    if training_params["use_number_category_cols"]
    else MODEL_TYPE_OTHER
)

TEST_DF_PATH = "rML-test.csv"  # test-split filename, resolved against the remote wfs
  30. def get_remote_gs_wfs():
  31. print("Retreiving location of remote working file system...")
  32. stream = os.popen("dvc remote list --local")
  33. output = stream.read()
  34. remote_wfs_loc = output.split("\t")[1].split("\n")[0]
  35. return remote_wfs_loc
  36. def load_transform_and_eval(remote_wfs, model_type=None, random_state=42):
  37. print("loading transformer and model...")
  38. if model_type == MODEL_TYPE_TEXT:
  39. model = joblib.load(reddit_utils.MODEL_PATH)
  40. tfidf = joblib.load(reddit_utils.TFIDF_PATH)
  41. else:
  42. # TODO
  43. return
  44. y_proba = np.array([])
  45. y_pred = np.array([])
  46. y = np.array([])
  47. print("Loading test data and testing model...")
  48. for i, chunk in enumerate(
  49. pd.read_csv(os.path.join(remote_wfs, TEST_DF_PATH), chunksize=CHUNK_SIZE)
  50. ):
  51. print(f"Testing on chunk {i+1}...")
  52. test_tfidf = tfidf.transform(chunk["title_and_body"].values.astype("U"))
  53. y_proba = np.concatenate((y_pred, model.predict_proba(test_tfidf)[:, 1]))
  54. y_pred = np.concatenate((y_pred, model.predict(test_tfidf)))
  55. y = np.concatenate((y, chunk[TARGET_LABEL]))
  56. print("Calculating metrics")
  57. print(np.unique(y_proba), np.unique(y_pred), np.unique(y))
  58. metrics = calculate_metrics(y_pred, y_proba, y)
  59. print("Logging metrics...")
  60. with dagshub.dagshub_logger(should_log_hparams=False) as logger:
  61. logger.log_metrics(prepare_log(metrics, "test"))
  62. if __name__ == "__main__":
  63. remote_wfs = get_remote_gs_wfs()
  64. load_transform_and_eval(remote_wfs, MODEL_TYPE)
  65. print("Model evaluation done!")
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...