Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  1. import dagshub
  2. import os
  3. import pandas as pd
  4. import yaml
  5. import numpy as np
  6. import joblib
  7. from sklearn.linear_model import SGDClassifier
  8. from reddit_utils import calculate_metrics, prepare_log
  9. import reddit_utils
  10. with open(r"./general_params.yml") as f:
  11. params = yaml.safe_load(f)
  12. with open(r"./model_params.yml") as f:
  13. model_params = yaml.safe_load(f)
  14. CHUNK_SIZE = params["chunk_size"]
  15. TARGET_LABEL = params["target_col"]
  16. COLS_FOR_EVAL = []
  17. if model_params["use_text_cols"]:
  18. COLS_FOR_EVAL += reddit_utils.TEXT_COL_NAME
  19. if model_params["use_number_category_cols"]:
  20. COLS_FOR_EVAL += reddit_utils.NUM_COL_NAMES + reddit_utils.CAT_COL_NAMES
  21. TEST_DF_PATH = "rML-test.csv"
  22. def get_remote_gs_wfs():
  23. print("Retreiving location of remote working file system...")
  24. stream = os.popen("dvc remote list --local")
  25. output = stream.read()
  26. remote_wfs_loc = output.split("\t")[1].split("\n")[0]
  27. return remote_wfs_loc
  28. def load_transform_and_eval(remote_wfs, random_state=42):
  29. print("loading transformer and model...")
  30. model = joblib.load(reddit_utils.MODEL_PATH)
  31. y_proba = np.array([])
  32. y_pred = np.array([])
  33. y = np.array([])
  34. print("Loading test data and testing model...")
  35. for i, chunk in enumerate(
  36. pd.read_csv(os.path.join(remote_wfs, TEST_DF_PATH), chunksize=CHUNK_SIZE)
  37. ):
  38. print(f"Testing on chunk {i+1}...")
  39. df_X = chunk[COLS_FOR_EVAL]
  40. y_proba = np.concatenate((y_pred, model.predict_proba(df_X)[:, 1]))
  41. y_pred = np.concatenate((y_pred, model.predict(df_X)))
  42. y = np.concatenate((y, chunk[TARGET_LABEL]))
  43. print("Calculating metrics")
  44. metrics = calculate_metrics(y_pred, y_proba, y)
  45. print("Logging metrics...")
  46. with dagshub.dagshub_logger(should_log_hparams=False) as logger:
  47. logger.log_metrics(prepare_log(metrics, "test"))
  48. if __name__ == "__main__":
  49. remote_wfs = get_remote_gs_wfs()
  50. load_transform_and_eval(remote_wfs)
  51. print("Model evaluation done!")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Mr.abuhe

commented in commit 2e0f2b35fb on branch master

11 months ago Outdated

Ini tautan akun fb saya tapi tapi orang lain mengambil nya tolong biarkan saya masuk ke tautan ini. https://www.facebook.com/profile.php?id=100087038453254&mibextid=ZbWKwL

Mr.abuhe

commented in commit 2e0f2b35fb on branch master

11 months ago Outdated

id=100087038453254&mibextid=ZbWKwL@econsav838

Loading...