Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

logreg_classifier.py 5.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  1. import pickle
  2. import argparse
  3. import json
  4. import sys, os
  5. import pandas as pd
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. import scipy.stats
  9. from sklearn.linear_model import LogisticRegression
  10. from sklearn.multioutput import MultiOutputRegressor
  11. from sklearn.model_selection import train_test_split
  12. from sklearn.metrics import *
  13. import dagshub
  14. sys.path.append('./models_scripts/common/')
  15. from tools import *
  16. # def logreg_evaluate(df, code_blocks, TAG_TO_PREDICT):
  17. # code_blocks_tfidf = tfidf_fit_transform(code_blocks, tfidf_params, TFIDF_DIR)
  18. # X_train, X_test, y_train, y_test = train_test_split(code_blocks_tfidf, df[TAG_TO_PREDICT], test_size=0.25)
  19. # clf = LogisticRegression(random_state=421).fit(X_train, y_train)
  20. # print("inited the model")
  21. # pickle.dump(clf, open(MODEL_DIR, 'wb'))
  22. # print("saved the model")
  23. # y_pred = clf.predict(X_test)
  24. # accuracy = clf.score(X_test, y_test)
  25. # f1 = f1_score(y_pred, y_test, average='weighted')
  26. # print(f'Mean Accuracy {round(accuracy*100, 2)}%')
  27. # print(f'F1-score {round(f1*100, 2)}%')
  28. # errors = y_test - y_pred
  29. # plt.hist(errors)
  30. # # plot_precision_recall_curve(clf, X_test, y_test)
  31. # # plot_confusion_matrix(clf, X_test, y_test, values_format='d')
  32. # def mean_confidence_interval(data, confidence=0.95):
  33. # a = 1.0 * np.array(data)
  34. # n = len(a)
  35. # m, se = np.mean(a), scipy.stats.sem(a)
  36. # h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
  37. # return m, m-h, m+h
  38. # conf_interval = mean_confidence_interval(errors, 0.95)
  39. # print(conf_interval)
  40. # metrics = {'test_accuracy': accuracy
  41. # , 'test_f1_score': f1}
  42. # return metrics
  43. # def get_predictions(X, y, TAGS_TO_PREDICT, MODEL_DIR):
  44. # clf = pickle.load(open(MODEL_DIR, 'rb'))
  45. # # result = loaded_model.score(X, y)
  46. # y_pred = clf.predict(X)
  47. # accuracy = accuracy_score(y_pred, y)
  48. # f1 = f1_score(y_pred, y, average='weighted')
  49. # print(f'Mean Accuracy {round(accuracy*100, 2)}%')
  50. # print(f'F1-score {round(f1*100, 2)}%')
  51. # errors = y - y_pred
  52. # plt.hist(errors)
  53. # plot_precision_recall_curve(clf, X, y)
  54. # plot_confusion_matrix(clf, X, y, values_format='d')
  55. # def mean_confidence_interval(data, confidence=0.95):
  56. # a = 1.0 * np.array(data)
  57. # n = len(a)
  58. # m, se = np.mean(a), scipy.stats.sem(a)
  59. # h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
  60. # return m, m-h, m+h
  61. # conf_interval = mean_confidence_interval(errors, 0.95)
  62. # print(conf_interval)
  63. # metrics = {'test_accuracy': accuracy
  64. # , 'test_f1_score': f1}
  65. # return metrics
  66. def logreg_multioutput_evaluate(df, code_blocks, TAGS_TO_PREDICT):
  67. code_blocks_tfidf = tfidf_fit_transform(code_blocks, tfidf_params, TFIDF_DIR)
  68. print("tfifd-ed")
  69. X_train, X_test, Y_train, Y_test = train_test_split(code_blocks_tfidf, df[TAGS_TO_PREDICT], test_size=0.25)
  70. print("splitted to train and test")
  71. clf = MultiOutputRegressor(LogisticRegression(random_state=421)).fit(X_train, Y_train)
  72. print("trained the model")
  73. pickle.dump(clf, open(MODEL_DIR, 'wb'))
  74. print("saved the model")
  75. Y_pred = clf.predict(X_test)
  76. accuracy = clf.score(X_test, Y_test)
  77. f1 = f1_score(Y_pred, Y_test, average='weighted')
  78. print(f'Mean Accuracy {round(accuracy*100, 2)}%')
  79. print(f'F1-score {round(f1*100, 2)}%')
  80. # errors = Y_test - Y_pred
  81. # plt.hist(errors)
  82. # plot_precision_recall_curve(clf, X_test, Y_test)
  83. # plot_confusion_matrix(clf, X_test, Y_test, values_format='d')
  84. metrics = {'test_accuracy': accuracy
  85. , 'test_f1_score': f1}
  86. return metrics
  87. try:
  88. parser = argparse.ArgumentParser()
  89. parser.add_argument("GRAPH_VER", help="version of the graph you want regex to label your CSV with", type=int)
  90. parser.add_argument("DATASET_PATH", help="path to your input CSV", type=str)
  91. args = parser.parse_args()
  92. GRAPH_VER = args.GRAPH_VER
  93. DATASET_PATH = args.DATASET_PATH
  94. except:
  95. print('Got no arguments, taking default arguments from params.yaml')
  96. with open("params.yaml", 'r') as fd:
  97. params = yaml.safe_load(fd)
  98. GRAPH_VER = params['GRAPH_VER']
  99. DATASET_PATH = params['regex']['DATASET_PATH']
  100. # REPO_PATH = os.path.dirname(os.path.abspath(__file__)).replace('\\', '/') + '/'
  101. MODEL_DIR = './models/logreg_regex_graph_v{}.sav'.format(GRAPH_VER)
  102. TFIDF_DIR = './models/tfidf_logreg_graph_v{}.pickle'.format(GRAPH_VER)
  103. CODE_COLUMN = 'code_block'
  104. TAGS_TO_PREDICT = get_graph_vertices(GRAPH_VER)
  105. PREDICT_COL = 'pred_{}'.format(TAGS_TO_PREDICT)
  106. SCRIPT_DIR = 'logreg_classifier.ipynb'
  107. TASK = 'training LogReg'
  108. if __name__ == '__main__':
  109. df = load_data(DATASET_PATH)
  110. code_blocks = df[CODE_COLUMN]
  111. nrows = df.shape[0]
  112. print("loaded the data")
  113. tfidf_params = {'min_df': 5
  114. , 'max_df': 0.3
  115. , 'smooth_idf': True}
  116. data_meta = {'DATASET_PATH': DATASET_PATH
  117. ,'TFIDF_DIR': TFIDF_DIR
  118. ,'MODEL_DIR': MODEL_DIR
  119. ,'nrows': nrows
  120. ,'label': TAGS_TO_PREDICT
  121. ,'graph_ver': GRAPH_VER
  122. ,'script_dir': SCRIPT_DIR
  123. ,'task': TASK}
  124. with dagshub.dagshub_logger() as logger:
  125. metrics = logreg_multioutput_evaluate(df, code_blocks, TAGS_TO_PREDICT)
  126. logger.log_hyperparams(data_meta)
  127. logger.log_hyperparams(tfidf_params)
  128. logger.log_metrics(metrics)
  129. print("saved the dicts")
  130. print("finished")
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...