evals.py

import math
import warnings
from copy import deepcopy

import numpy as np
from sklearn import metrics

warnings.filterwarnings(action='ignore', category=DeprecationWarning)
warnings.filterwarnings(action='ignore', category=RuntimeWarning)
def ranking_precision_score(Y_true, Y_score, k=10):
    """Precision at rank k.

    Parameters
    ----------
    Y_true : array-like, shape = [n_samples, n_labels]
        Ground truth (binary relevance labels).
    Y_score : array-like, shape = [n_samples, n_labels]
        Predicted scores.
    k : int
        Rank.

    Returns
    -------
    precision@k : float
    """
    unique_Y = np.unique(Y_true)
    if len(unique_Y) > 2:
        raise ValueError("Only supported for two relevance levels.")
    pos_label = unique_Y[1]
    # Sort each row's labels by descending score and keep the top k.
    order = np.argsort(Y_score, axis=1)[:, ::-1]
    Y_true = np.array([x[y] for x, y in zip(Y_true, order[:, :k])])
    # Fraction of the top-k labels that are relevant, averaged over samples.
    n_relevant = np.sum(Y_true == pos_label, axis=1)
    prec = np.divide(n_relevant.astype(float), k)
    return np.average(prec)
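# Worked example (added for illustration; not part of the original file):
# with one sample whose two highest-scored labels are both relevant,
# precision@2 is 1.0.
#
#   >>> Y_true = np.array([[0, 1, 1, 0]])
#   >>> Y_score = np.array([[0.1, 0.9, 0.8, 0.2]])
#   >>> ranking_precision_score(Y_true, Y_score, k=2)
#   1.0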
def subset_accuracy(true_targets, predictions, per_sample=False, axis=0):
    # Exact-match accuracy: a sample counts only if every label is correct.
    result = np.all(true_targets == predictions, axis=axis)
    if not per_sample:
        result = np.mean(result)
    return result


def hamming_loss(true_targets, predictions, per_sample=False, axis=0):
    # Fraction of label positions where prediction and target disagree.
    result = np.mean(np.logical_xor(true_targets, predictions), axis=axis)
    if not per_sample:
        result = np.mean(result)
    return result
def compute_tp_fp_fn(true_targets, predictions, axis=0):
    # Per-label (axis=0) or per-sample (axis=1) true positive, false
    # positive, and false negative counts for binary indicator matrices.
    tp = np.sum(true_targets * predictions, axis=axis).astype('float32')
    fp = np.sum(np.logical_not(true_targets) * predictions,
                axis=axis).astype('float32')
    fn = np.sum(true_targets * np.logical_not(predictions),
                axis=axis).astype('float32')
    return (tp, fp, fn)
def example_f1_score(true_targets, predictions, per_sample=False, axis=0):
    tp, _, _ = compute_tp_fp_fn(true_targets, predictions, axis=axis)
    numerator = 2 * tp
    denominator = (np.sum(true_targets, axis=axis).astype('float32')
                   + np.sum(predictions, axis=axis).astype('float32'))
    # Drop samples with no true and no predicted labels (F1 is undefined).
    zeros = np.where(denominator == 0)[0]
    denominator = np.delete(denominator, zeros)
    numerator = np.delete(numerator, zeros)
    example_f1 = numerator / denominator
    if per_sample:
        f1 = example_f1
    else:
        f1 = np.mean(example_f1)
    return f1
def f1_score_from_stats(tp, fp, fn, average='micro'):
    assert len(tp) == len(fp)
    assert len(fp) == len(fn)
    if average not in set(['micro', 'macro']):
        raise ValueError("Specify micro or macro")
    if average == 'micro':
        # Pool the counts across all labels, then compute a single F1.
        f1 = 2 * np.sum(tp) / \
            float(2 * np.sum(tp) + np.sum(fp) + np.sum(fn))
    elif average == 'macro':
        def safe_div(a, b):
            """Drop non-finite entries: div0([-1, 0, 1], 0) -> []."""
            with np.errstate(divide='ignore', invalid='ignore'):
                c = np.true_divide(a, b)
            return c[np.isfinite(c)]
        # Per-label F1 (the epsilon keeps the denominator nonzero), then mean.
        f1 = np.mean(safe_div(2 * tp, 2 * tp + fp + fn + 1e-6))
    return f1
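# Worked example (illustrative, not from the original file): per-label
# stats tp=[1, 2], fp=[1, 0], fn=[0, 1] give
#   micro-F1 = 2*3 / (2*3 + 1 + 1)                 = 0.75
#   macro-F1 = mean(2/3.000001, 4/5.000001)       ~= 0.733
# Micro weights every decision equally; macro weights every label equally.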
def f1_score(true_targets, predictions, average='micro', axis=0):
    """
    average: str
        'micro' or 'macro'
    axis: 0 or 1
        label axis
    """
    if average not in set(['micro', 'macro']):
        raise ValueError("Specify micro or macro")
    tp, fp, fn = compute_tp_fp_fn(true_targets, predictions, axis=axis)
    f1 = f1_score_from_stats(tp, fp, fn, average=average)
    return f1
def compute_fdr(all_targets, all_predictions, fdr_cutoff=0.5):
    fdr_array = []
    for i in range(all_targets.shape[1]):
        try:
            precision, recall, thresholds = metrics.precision_recall_curve(
                all_targets[:, i], all_predictions[:, i], pos_label=1)
            fdr = 1 - precision
            # Recall at the first point where the FDR drops to the cutoff.
            cutoff_index = next(j for j, x in enumerate(fdr) if x <= fdr_cutoff)
            fdr_at_cutoff = recall[cutoff_index]
            if not math.isnan(fdr_at_cutoff):
                fdr_array.append(np.nan_to_num(fdr_at_cutoff))
        except (ValueError, StopIteration):
            # Skip labels where the curve or the cutoff is undefined.
            pass
    fdr_array = np.array(fdr_array)
    mean_fdr = np.mean(fdr_array)
    median_fdr = np.median(fdr_array)
    var_fdr = np.var(fdr_array)
    return mean_fdr, median_fdr, var_fdr, fdr_array
def compute_aupr(all_targets, all_predictions):
    aupr_array = []
    for i in range(all_targets.shape[1]):
        # Area under the per-label precision-recall curve.
        precision, recall, thresholds = metrics.precision_recall_curve(
            all_targets[:, i], all_predictions[:, i], pos_label=1)
        auPR = metrics.auc(recall, precision)
        if not math.isnan(auPR):
            aupr_array.append(np.nan_to_num(auPR))
    aupr_array = np.array(aupr_array)
    mean_aupr = np.mean(aupr_array)
    median_aupr = np.median(aupr_array)
    var_aupr = np.var(aupr_array)
    return mean_aupr, median_aupr, var_aupr, aupr_array
def compute_auc(all_targets, all_predictions):
    auc_array = []
    for i in range(all_targets.shape[1]):
        try:
            auROC = metrics.roc_auc_score(all_targets[:, i], all_predictions[:, i])
            auc_array.append(auROC)
        except ValueError:
            # roc_auc_score is undefined when a label has only one class.
            pass
    auc_array = np.array(auc_array)
    mean_auc = np.mean(auc_array)
    median_auc = np.median(auc_array)
    var_auc = np.var(auc_array)
    return mean_auc, median_auc, var_auc, auc_array
def compute_metrics(predictions, targets, threshold, all_metrics=True):
    # Work on copies so the caller's arrays are not thresholded in place.
    all_targets = deepcopy(targets)
    all_predictions = deepcopy(predictions)

    if all_metrics:
        meanAUC, medianAUC, varAUC, allAUC = compute_auc(all_targets, all_predictions)
        meanAUPR, medianAUPR, varAUPR, allAUPR = compute_aupr(all_targets, all_predictions)
        meanFDR, medianFDR, varFDR, allFDR = compute_fdr(all_targets, all_predictions)
    else:
        meanAUC, medianAUC, varAUC, allAUC = 0, 0, 0, 0
        meanAUPR, medianAUPR, varAUPR, allAUPR = 0, 0, 0, 0
        meanFDR, medianFDR, varFDR, allFDR = 0, 0, 0, 0

    # p_at_1 = ranking_precision_score(Y_true=all_targets, Y_score=all_predictions, k=1)
    # p_at_3 = ranking_precision_score(Y_true=all_targets, Y_score=all_predictions, k=3)
    # p_at_5 = ranking_precision_score(Y_true=all_targets, Y_score=all_predictions, k=5)

    # Binarize the scores before computing the threshold-dependent metrics.
    optimal_threshold = threshold
    all_predictions[all_predictions < optimal_threshold] = 0
    all_predictions[all_predictions >= optimal_threshold] = 1

    acc_ = list(subset_accuracy(all_targets, all_predictions, axis=1, per_sample=True))
    hl_ = list(hamming_loss(all_targets, all_predictions, axis=1, per_sample=True))
    exf1_ = list(example_f1_score(all_targets, all_predictions, axis=1, per_sample=True))
    ACC = np.mean(acc_)
    hl = np.mean(hl_)
    HA = 1 - hl
    ebF1 = np.mean(exf1_)

    tp, fp, fn = compute_tp_fp_fn(all_targets, all_predictions, axis=0)
    miF1 = f1_score_from_stats(tp, fp, fn, average='micro')
    maF1 = f1_score_from_stats(tp, fp, fn, average='macro')

    metrics_dict = {
        'ACC': ACC,
        'HA': HA,
        'ebF1': ebF1,
        'miF1': miF1,
        'maF1': maF1,
        'meanAUC': meanAUC,
        'medianAUC': medianAUC,
        'varAUC': varAUC,
        'allAUC': allAUC,
        'meanAUPR': meanAUPR,
        'medianAUPR': medianAUPR,
        'varAUPR': varAUPR,
        'allAUPR': allAUPR,
        'meanFDR': meanFDR,
        'medianFDR': medianFDR,
        'varFDR': varFDR,
        'allFDR': allFDR,
        # 'p_at_1': p_at_1,
        # 'p_at_3': p_at_3,
        # 'p_at_5': p_at_5,
    }
    return metrics_dict
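# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# smoke test of compute_metrics on random multi-label data. The shapes, the
# random seed, and the 0.5 threshold are assumptions for this demo, not
# values the module prescribes.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    n_samples, n_labels = 64, 5
    targets = rng.randint(0, 2, size=(n_samples, n_labels))   # binary ground truth
    scores = rng.uniform(size=(n_samples, n_labels))          # predicted probabilities
    results = compute_metrics(scores, targets, threshold=0.5)
    for name in ('ACC', 'HA', 'ebF1', 'miF1', 'maF1',
                 'meanAUC', 'meanAUPR', 'meanFDR'):
        print('%-8s %.4f' % (name, results[name]))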