Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

predict_tag.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
  1. import pickle
  2. import argparse
  3. import sys
  4. import pandas as pd
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import scipy
  8. from sklearn.metrics import *
  9. from sklearn.feature_extraction.text import TfidfVectorizer
  10. from sklearn.svm import SVC
  11. import dagshub
  12. sys.path.append('./models_scripts/common/')
  13. from tools import *
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument("GRAPH_VER", help="version of the graph you want regex to label your CSV with", type=int)
  16. parser.add_argument("DATASET_PATH", help="path to your input CSV", type=str)
  17. parser.add_argument("MODEL", help="model to use", type=str)
  18. args = parser.parse_args()
  19. GRAPH_VER = args.GRAPH_VER
  20. DATASET_PATH = args.DATASET_PATH
  21. MODEL = args.MODEL # MODEL = 'svm'
  22. TASK = 'model validation' # 'model evaluation'
  23. MODEL_DIR = './models/{}_regex_graph_v{}.sav'.format(MODEL, GRAPH_VER)
  24. TFIDF_DIR = './models/tfidf_{}_graph_v{}.pickle'.format(MODEL, GRAPH_VER)
  25. CODE_COLUMN = 'code_block'
  26. TAGS_TO_PREDICT = get_graph_vertices(GRAPH_VER)
  27. SCRIPT_DIR = './predict_tag.ipynb'
  28. if __name__ == '__main__':
  29. df = load_data(DATASET_PATH)
  30. code_blocks = df[CODE_COLUMN]
  31. nrows = df.shape[0]
  32. print("loaded")
  33. tfidf_params = {'min_df': 5
  34. , 'max_df': 0.3
  35. , 'smooth_idf': True}
  36. meta = {'DATASET_PATH': DATASET_PATH
  37. ,'TFIDF_DIR': TFIDF_DIR
  38. ,'MODEL_DIR': MODEL_DIR
  39. ,'nrows': nrows
  40. ,'label': TAGS_TO_PREDICT
  41. ,'model': MODEL
  42. ,'graph_ver': GRAPH_VER
  43. ,'script_dir': SCRIPT_DIR
  44. ,'task': TASK}
  45. code_blocks_tfidf = tfidf_transform(code_blocks, tfidf_params, TFIDF_DIR)
  46. with dagshub.dagshub_logger() as logger:
  47. _, y, y_pred, metrics = get_metrics(code_blocks_tfidf, df[TAGS_TO_PREDICT], TAGS_TO_PREDICT, MODEL_DIR)
  48. logger.log_hyperparams(meta)
  49. logger.log_metrics(metrics)
  50. print("finished")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...