Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

keras_word_prediction_layer.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
  1. import tensorflow as tf
  2. from tensorflow.python import keras
  3. from tensorflow.python.keras.layers import Layer
  4. import tensorflow.python.keras.backend as K
  5. from typing import Optional, List, Callable
  6. from functools import reduce
  7. from common import common
  class WordPredictionLayer(Layer):
      """Non-trainable Keras layer that turns a score matrix into one word per row.

      For every row of the rank-2 input, the `top_k` highest-scoring indices are
      taken, translated through `index_to_word_table`, optionally filtered, and
      the first candidate that survives all filters is emitted.
      """

      # Signature of a candidate filter: it is called with
      # (top-k indices, looked-up words) and returns a mask tensor.
      # NOTE(review): masks are combined with tf.logical_and, so they are
      # presumably boolean and shaped like the top-k indices -- confirm.
      FilterType = Callable[[tf.Tensor, tf.Tensor], tf.Tensor]

      def __init__(self,
                   top_k: int,
                   index_to_word_table: tf.contrib.lookup.HashTable,
                   predicted_words_filters: Optional[List[FilterType]] = None,
                   **kwargs):
          """Store the lookup configuration and initialise the base layer.

          Args:
              top_k: how many top-scoring candidates to consider per row.
              index_to_word_table: table used to look up a word for each
                  candidate index; its `key_dtype` decides the index dtype.
              predicted_words_filters: optional filters; a candidate is legal
                  only when every filter's mask accepts it.
              **kwargs: forwarded to the base `Layer`; `dtype` and `trainable`
                  are always overridden below.
          """
          # The layer emits strings and never has anything to train.
          kwargs.update(dtype=tf.string, trainable=False)
          super().__init__(**kwargs)
          self.predicted_words_filters = predicted_words_filters
          self.index_to_word_table = index_to_word_table
          self.top_k = top_k

      def build(self, input_shape):
          """Check that the input is 2-dimensional, then finish building.

          Raises:
              ValueError: if `input_shape` does not have exactly two entries.
          """
          rank = len(input_shape)
          if rank != 2:
              raise ValueError("Input shape for WordPredictionLayer should be of 2 dimension.")
          super().build(input_shape)
          # Re-assert non-trainability after the base class has built.
          self.trainable = False

      def call(self, y_pred, **kwargs):
          """Return the first legal predicted word for each row of `y_pred`.

          Args:
              y_pred: rank-2 tensor of scores (rank is asserted).
              **kwargs: accepted for Keras compatibility; unused here.

          Returns:
              A tensor flattened to rank 1 holding the selected words.
          """
          y_pred.shape.assert_has_rank(2)

          # Indices of the best-scoring candidates, cast to the table's key type.
          top_k_result = tf.nn.top_k(y_pred, k=self.top_k)
          key_dtype = self.index_to_word_table.key_dtype
          candidate_indices = tf.cast(top_k_result.indices, dtype=key_dtype)
          candidate_words = self.index_to_word_table.lookup(candidate_indices)

          # Evaluate every configured filter on the candidates.
          active_filters = self.predicted_words_filters
          if active_filters is None:
              active_filters = []
          filter_masks = [word_filter(candidate_indices, candidate_words)
                          for word_filter in active_filters]

          if not filter_masks:
              # No filters: every candidate is legal.
              all_ones = tf.ones_like(candidate_indices)
              legal_mask = tf.cast(all_ones, dtype=tf.bool)
          else:
              # A candidate is legal only if all filters accept it.
              legal_mask = reduce(tf.logical_and, filter_masks)

          # The first legal candidate is the prediction.
          # NOTE(review): presumably tf_get_first_true keeps only the first True
          # entry of each row; if a row has no legal candidate the output may
          # hold fewer entries than the batch size -- confirm against `common`.
          first_legal_mask = common.tf_get_first_true(legal_mask)
          first_legal_coords = tf.where(first_legal_mask)
          chosen_words = tf.gather_nd(candidate_words, first_legal_coords)
          return tf.reshape(chosen_words, [-1])

      def compute_output_shape(self, input_shape):
          """Return a 1-tuple holding the first entry of `input_shape`, i.e. (batch,)."""
          return (input_shape[0],)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...