Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

keras_words_subtoken_metrics.py 6.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
  1. import tensorflow as tf
  2. import tensorflow.keras.backend as K
  3. import abc
  4. from typing import Optional, Callable, List
  5. from functools import reduce
  6. from common import common
  7. class WordsSubtokenMetricBase(tf.metrics.Metric): # KIR
  8. FilterType = Callable[[tf.Tensor, tf.Tensor], tf.Tensor]
  9. def __init__(self,
  10. index_to_word_table: Optional[tf.lookup.StaticHashTable] = None,
  11. topk_predicted_words=None,
  12. predicted_words_filters: Optional[List[FilterType]] = None,
  13. subtokens_delimiter: str = '|', name=None, dtype=None):
  14. super(WordsSubtokenMetricBase, self).__init__(name=name, dtype=dtype)
  15. self.tp = self.add_weight('true_positives', shape=(), initializer=tf.zeros_initializer)
  16. self.fp = self.add_weight('false_positives', shape=(), initializer=tf.zeros_initializer)
  17. self.fn = self.add_weight('false_negatives', shape=(), initializer=tf.zeros_initializer)
  18. self.index_to_word_table = index_to_word_table
  19. self.topk_predicted_words = topk_predicted_words
  20. self.predicted_words_filters = predicted_words_filters
  21. self.subtokens_delimiter = subtokens_delimiter
  22. def _get_true_target_word_string(self, true_target_word):
  23. if self.index_to_word_table is None:
  24. return true_target_word
  25. true_target_word_index = tf.cast(true_target_word, dtype=self.index_to_word_table.key_dtype)
  26. return self.index_to_word_table.lookup(true_target_word_index)
  27. def update_state(self, true_target_word, predictions, sample_weight=None):
  28. """Accumulates true positive, false positive and false negative statistics."""
  29. if sample_weight is not None:
  30. raise NotImplemented("WordsSubtokenMetricBase with non-None `sample_weight` is not implemented.")
  31. # For each example in the batch we have:
  32. # (i) one ground true target word;
  33. # (ii) one predicted word (argmax y_hat)
  34. topk_predicted_words = predictions if self.topk_predicted_words is None else self.topk_predicted_words
  35. assert topk_predicted_words is not None
  36. predicted_word = self._get_prediction_from_topk(topk_predicted_words)
  37. true_target_word_string = self._get_true_target_word_string(true_target_word)
  38. true_target_word_string = tf.reshape(true_target_word_string, [-1])
  39. # We split each word into subtokens
  40. true_target_subwords = tf.compat.v1.string_split(true_target_word_string, sep=self.subtokens_delimiter)
  41. prediction_subwords = tf.compat.v1.string_split(predicted_word, sep=self.subtokens_delimiter)
  42. true_target_subwords = tf.sparse.to_dense(true_target_subwords, default_value='<PAD>')
  43. prediction_subwords = tf.sparse.to_dense(prediction_subwords, default_value='<PAD>')
  44. true_target_subwords_mask = tf.not_equal(true_target_subwords, '<PAD>')
  45. prediction_subwords_mask = tf.not_equal(prediction_subwords, '<PAD>')
  46. # Now shapes of true_target_subwords & true_target_subwords are (batch, subtokens)
  47. # We use broadcast to calculate 2 lists difference with duplicates preserving.
  48. true_target_subwords = tf.expand_dims(true_target_subwords, -1)
  49. prediction_subwords = tf.expand_dims(prediction_subwords, -1)
  50. # Now shapes of true_target_subwords & true_target_subwords are (batch, subtokens, 1)
  51. true_target_subwords__in__prediction_subwords = \
  52. tf.reduce_any(tf.equal(true_target_subwords, tf.transpose(prediction_subwords, perm=[0, 2, 1])), axis=2)
  53. prediction_subwords__in__true_target_subwords = \
  54. tf.reduce_any(tf.equal(prediction_subwords, tf.transpose(true_target_subwords, perm=[0, 2, 1])), axis=2)
  55. # Count ground true label subwords that exist in the predicted word.
  56. batch_true_positive = tf.reduce_sum(tf.cast(
  57. tf.logical_and(prediction_subwords__in__true_target_subwords, prediction_subwords_mask), tf.float32))
  58. # Count ground true label subwords that don't exist in the predicted word.
  59. batch_false_positive = tf.reduce_sum(tf.cast(
  60. tf.logical_and(~prediction_subwords__in__true_target_subwords, prediction_subwords_mask), tf.float32))
  61. # Count predicted word subwords that don't exist in the ground true label.
  62. batch_false_negative = tf.reduce_sum(tf.cast(
  63. tf.logical_and(~true_target_subwords__in__prediction_subwords, true_target_subwords_mask), tf.float32))
  64. self.tp.assign_add(batch_true_positive)
  65. self.fp.assign_add(batch_false_positive)
  66. self.fn.assign_add(batch_false_negative)
  67. def _get_prediction_from_topk(self, topk_predicted_words):
  68. # apply given filter
  69. masks = []
  70. if self.predicted_words_filters is not None:
  71. masks = [fltr(topk_predicted_words) for fltr in self.predicted_words_filters]
  72. if masks:
  73. # assert all(mask.shape.assert_is_compatible_with(top_k_pred_indices) for mask in masks)
  74. legal_predicted_target_words_mask = reduce(tf.logical_and, masks)
  75. else:
  76. legal_predicted_target_words_mask = tf.cast(tf.ones_like(topk_predicted_words), dtype=tf.bool)
  77. # the first legal predicted word is our prediction
  78. first_legal_predicted_target_word_mask = common.tf_get_first_true(legal_predicted_target_words_mask)
  79. first_legal_predicted_target_word_idx = tf.where(first_legal_predicted_target_word_mask)
  80. first_legal_predicted_word_string = tf.gather_nd(topk_predicted_words,
  81. first_legal_predicted_target_word_idx)
  82. prediction = tf.reshape(first_legal_predicted_word_string, [-1])
  83. return prediction
  84. @abc.abstractmethod
  85. def result(self):
  86. ...
  87. def reset_states(self):
  88. for v in self.variables:
  89. K.set_value(v, 0)
  90. class WordsSubtokenPrecisionMetric(WordsSubtokenMetricBase):
  91. def result(self):
  92. precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
  93. return precision
  94. class WordsSubtokenRecallMetric(WordsSubtokenMetricBase):
  95. def result(self):
  96. recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
  97. return recall
  98. class WordsSubtokenF1Metric(WordsSubtokenMetricBase):
  99. def result(self):
  100. recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
  101. precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
  102. f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall + K.epsilon())
  103. return f1
Tip!

Press p (or the left-arrow key) to see the previous file, or n (or the right-arrow key) to see the next file

Comments

Loading...