Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model_base.py 7.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
  1. import numpy as np
  2. import abc
  3. import os
  4. from typing import NamedTuple, Optional, List, Dict, Tuple, Iterable
  5. from common import common
  6. from vocabularies import Code2VecVocabs, VocabType
  7. from config import Config
  8. class ModelEvaluationResults(NamedTuple):
  9. topk_acc: float
  10. subtoken_precision: float
  11. subtoken_recall: float
  12. subtoken_f1: float
  13. loss: Optional[float] = None
  14. def __str__(self):
  15. res_str = 'topk_acc: {topk_acc}, precision: {precision}, recall: {recall}, F1: {f1}'.format(
  16. topk_acc=self.topk_acc,
  17. precision=self.subtoken_precision,
  18. recall=self.subtoken_recall,
  19. f1=self.subtoken_f1)
  20. if self.loss is not None:
  21. res_str = ('loss: {}, '.format(self.loss)) + res_str
  22. return res_str
class ModelPredictionResults(NamedTuple):
    """Prediction output of the model for a single input example."""
    # The example's actual (ground-truth) target name, as a string.
    original_name: str
    # Top-k predicted names and their scores; parallel arrays
    # (presumably ordered best-first — TODO confirm against the implementation model).
    topk_predicted_words: np.ndarray
    topk_predicted_words_scores: np.ndarray
    # Maps a (path_source, path, path_target) string triplet to its attention weight,
    # as built by `Code2VecModelBase._get_attention_weight_per_context`.
    attention_per_context: Dict[Tuple[str, str, str], float]
    # The example's code embedding; only populated when the model exports code vectors.
    code_vector: Optional[np.ndarray] = None
  29. class Code2VecModelBase(abc.ABC):
  30. def __init__(self, config: Config):
  31. self.config = config
  32. self.config.verify()
  33. self._log_creating_model()
  34. if not config.RELEASE:
  35. self._init_num_of_examples()
  36. self._log_model_configuration()
  37. self.vocabs = Code2VecVocabs(config)
  38. self.vocabs.target_vocab.get_index_to_word_lookup_table() # just to initialize it (if not already initialized)
  39. self._load_or_create_inner_model()
  40. self._initialize()
  41. def _log_creating_model(self):
  42. self.log('')
  43. self.log('')
  44. self.log('---------------------------------------------------------------------')
  45. self.log('---------------------------------------------------------------------')
  46. self.log('---------------------- Creating code2vec model ----------------------')
  47. self.log('---------------------------------------------------------------------')
  48. self.log('---------------------------------------------------------------------')
  49. def _log_model_configuration(self):
  50. self.log('---------------------------------------------------------------------')
  51. self.log('----------------- Configuration - Hyper Parameters ------------------')
  52. longest_param_name_len = max(len(param_name) for param_name, _ in self.config)
  53. for param_name, param_val in self.config:
  54. self.log('{name: <{name_len}}{val}'.format(
  55. name=param_name, val=param_val, name_len=longest_param_name_len+2))
  56. self.log('---------------------------------------------------------------------')
  57. @property
  58. def logger(self):
  59. return self.config.get_logger()
  60. def log(self, msg):
  61. self.logger.info(msg)
  62. def _init_num_of_examples(self):
  63. self.log('Checking number of examples ...')
  64. if self.config.is_training:
  65. self.config.NUM_TRAIN_EXAMPLES = self._get_num_of_examples_for_dataset(self.config.train_data_path)
  66. self.log(' Number of train examples: {}'.format(self.config.NUM_TRAIN_EXAMPLES))
  67. if self.config.is_testing:
  68. self.config.NUM_TEST_EXAMPLES = self._get_num_of_examples_for_dataset(self.config.TEST_DATA_PATH)
  69. self.log(' Number of test examples: {}'.format(self.config.NUM_TEST_EXAMPLES))
  70. @staticmethod
  71. def _get_num_of_examples_for_dataset(dataset_path: str) -> int:
  72. dataset_num_examples_file_path = dataset_path + '.num_examples'
  73. if os.path.isfile(dataset_num_examples_file_path):
  74. with open(dataset_num_examples_file_path, 'r') as file:
  75. num_examples_in_dataset = int(file.readline())
  76. else:
  77. num_examples_in_dataset = common.count_lines_in_file(dataset_path)
  78. with open(dataset_num_examples_file_path, 'w') as file:
  79. file.write(str(num_examples_in_dataset))
  80. return num_examples_in_dataset
  81. def load_or_build(self):
  82. self.vocabs = Code2VecVocabs(self.config)
  83. self._load_or_create_inner_model()
  84. def save(self, model_save_path=None):
  85. if model_save_path is None:
  86. model_save_path = self.config.MODEL_SAVE_PATH
  87. model_save_dir = '/'.join(model_save_path.split('/')[:-1])
  88. if not os.path.isdir(model_save_dir):
  89. os.makedirs(model_save_dir, exist_ok=True)
  90. self.vocabs.save(self.config.get_vocabularies_path_from_model_path(model_save_path))
  91. self._save_inner_model(model_save_path)
  92. def _write_code_vectors(self, file, code_vectors):
  93. for vec in code_vectors:
  94. file.write(' '.join(map(str, vec)) + '\n')
  95. def _get_attention_weight_per_context(
  96. self, path_source_strings: Iterable[str], path_strings: Iterable[str], path_target_strings: Iterable[str],
  97. attention_weights: Iterable[float]) -> Dict[Tuple[str, str, str], float]:
  98. attention_weights = np.squeeze(attention_weights, axis=-1) # (max_contexts, )
  99. attention_per_context: Dict[Tuple[str, str, str], float] = {}
  100. # shape of path_source_strings, path_strings, path_target_strings, attention_weights is (max_contexts, )
  101. # iterate over contexts
  102. for path_source, path, path_target, weight in \
  103. zip(path_source_strings, path_strings, path_target_strings, attention_weights):
  104. string_context_triplet = (common.binary_to_string(path_source),
  105. common.binary_to_string(path),
  106. common.binary_to_string(path_target))
  107. attention_per_context[string_context_triplet] = weight
  108. return attention_per_context
  109. def close_session(self):
  110. # can be overridden by the implementation model class.
  111. # default implementation just does nothing.
  112. pass
  113. @abc.abstractmethod
  114. def train(self):
  115. ...
  116. @abc.abstractmethod
  117. def evaluate(self) -> Optional[ModelEvaluationResults]:
  118. ...
  119. @abc.abstractmethod
  120. def predict(self, predict_data_lines: Iterable[str]) -> List[ModelPredictionResults]:
  121. ...
  122. @abc.abstractmethod
  123. def _save_inner_model(self, path):
  124. ...
  125. def _load_or_create_inner_model(self):
  126. if self.config.is_loading:
  127. self._load_inner_model()
  128. else:
  129. self._create_inner_model()
  130. @abc.abstractmethod
  131. def _load_inner_model(self):
  132. ...
  133. def _create_inner_model(self):
  134. # can be overridden by the implementation model class.
  135. # default implementation just does nothing.
  136. pass
  137. def _initialize(self):
  138. # can be overridden by the implementation model class.
  139. # default implementation just does nothing.
  140. pass
  141. @abc.abstractmethod
  142. def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
  143. ...
  144. def save_word2vec_format(self, dest_save_path: str, vocab_type: VocabType):
  145. if vocab_type not in VocabType:
  146. raise ValueError('`vocab_type` should be `VocabType.Token`, `VocabType.Target` or `VocabType.Path`.')
  147. vocab_embedding_matrix = self._get_vocab_embedding_as_np_array(vocab_type)
  148. index_to_word = self.vocabs.get(vocab_type).index_to_word
  149. with open(dest_save_path, 'w') as words_file:
  150. common.save_word2vec_file(words_file, index_to_word, vocab_embedding_matrix)
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...