Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
  1. import dagshub
  2. import os
  3. import pandas as pd
  4. import yaml
  5. import re
  6. import numpy as np
  7. import joblib
  8. from scipy.sparse.dia import dia_matrix
  9. from sklearn.feature_extraction.text import TfidfVectorizer
  10. from sklearn.linear_model import SGDClassifier
  11. from reddit_utils import calculate_metrics, prepare_log
  12. import reddit_utils
  13. with open(r"./general_params.yml") as f:
  14. params = yaml.safe_load(f)
  15. with open(r"./training_params.yml") as f:
  16. training_params = yaml.safe_load(f)
  17. CHUNK_SIZE = params["chunk_size"]
  18. TARGET_LABEL = params["target_col"]
  19. MODEL_TYPE_TEXT = "model_text"
  20. MODEL_TYPE_NUM_CAT = "model_num_cat"
  21. MODEL_TYPE_OTHER = ""
  22. MODEL_TYPE = (
  23. MODEL_TYPE_TEXT
  24. if training_params["use_text_cols"]
  25. else MODEL_TYPE_NUM_CAT
  26. if training_params["use_number_category_cols"]
  27. else MODEL_TYPE_OTHER
  28. )
  29. TEST_DF_PATH = "rML-test.csv"
  30. def get_remote_gs_wfs():
  31. print("Retreiving location of remote working file system...")
  32. stream = os.popen("dvc remote list --local")
  33. output = stream.read()
  34. remote_wfs_loc = output.split("\t")[1].split("\n")[0]
  35. return remote_wfs_loc
  36. def load_transform_and_eval(remote_wfs, model_type=None, random_state=42):
  37. print("loading transformer and model...")
  38. if model_type == MODEL_TYPE_TEXT:
  39. model = joblib.load(reddit_utils.MODEL_PATH)
  40. tfidf = joblib.load(reddit_utils.TFIDF_PATH)
  41. else:
  42. # TODO
  43. return
  44. y_proba = np.array([])
  45. y_pred = np.array([])
  46. y = np.array([])
  47. print("Loading test data and testing model...")
  48. for i, chunk in enumerate(
  49. pd.read_csv(os.path.join(remote_wfs, TEST_DF_PATH), chunksize=CHUNK_SIZE)
  50. ):
  51. print(f"Testing on chunk {i+1}...")
  52. test_tfidf = tfidf.transform(chunk["title_and_body"].values.astype("U"))
  53. y_proba = np.concatenate((y_pred, model.predict_proba(test_tfidf)[:, 1]))
  54. y_pred = np.concatenate((y_pred, model.predict(test_tfidf)))
  55. y = np.concatenate((y, chunk[TARGET_LABEL]))
  56. print("Calculating metrics")
  57. print(np.unique(y_proba), np.unique(y_pred), np.unique(y))
  58. metrics = calculate_metrics(y_pred, y_proba, y)
  59. print("Logging metrics...")
  60. with dagshub.dagshub_logger(should_log_hparams=False) as logger:
  61. logger.log_metrics(prepare_log(metrics, "test"))
  62. if __name__ == "__main__":
  63. remote_wfs = get_remote_gs_wfs()
  64. load_transform_and_eval(remote_wfs, MODEL_TYPE)
  65. print("Model evaluation done!")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Mr.abuhe

commented in commit 2e0f2b35fb on branch master

11 months ago Outdated

Ini tautan akun fb saya tapi tapi orang lain mengambil nya tolong biarkan saya masuk ke tautan ini. https://www.facebook.com/profile.php?id=100087038453254&mibextid=ZbWKwL

Mr.abuhe

commented in commit 2e0f2b35fb on branch master

11 months ago Outdated

id=100087038453254&mibextid=ZbWKwL@econsav838

Loading...