keras_model.py

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Embedding, Concatenate, Dropout, TimeDistributed, Dense
from tensorflow.keras.callbacks import Callback
import tensorflow.keras.backend as K
from tensorflow.keras.metrics import sparse_top_k_categorical_accuracy

from path_context_reader import PathContextReader, ModelInputTensorsFormer, ReaderInputTensors, EstimatorAction
import os
import numpy as np
from functools import partial
from typing import List, Optional, Iterable, Union, Callable, Dict
from collections import namedtuple
import time
import datetime

from vocabularies import VocabType
from keras_attention_layer import AttentionLayer
from keras_topk_word_predictions_layer import TopKWordPredictionsLayer
from keras_words_subtoken_metrics import WordsSubtokenPrecisionMetric, WordsSubtokenRecallMetric, WordsSubtokenF1Metric
from config import Config
from common import common
from model_base import Code2VecModelBase, ModelEvaluationResults, ModelPredictionResults
from keras_checkpoint_saver_callback import ModelTrainingStatus, ModelTrainingStatusTrackerCallback, \
    ModelCheckpointSaverCallback, MultiBatchCallback, ModelTrainingProgressLoggerCallback


class Code2VecModel(Code2VecModelBase):
    def __init__(self, config: Config):
        self.keras_train_model: Optional[keras.Model] = None
        self.keras_eval_model: Optional[keras.Model] = None
        self.int_lev_model: Optional[keras.Model] = None
        self.keras_model_predict_function: Optional[K.GraphExecutionFunction] = None
        self.training_status: ModelTrainingStatus = ModelTrainingStatus()
        self._checkpoint: Optional[tf.train.Checkpoint] = None
        self._checkpoint_manager: Optional[tf.train.CheckpointManager] = None
        super(Code2VecModel, self).__init__(config)

    def _create_keras_model(self):
        # Each input sample consists of a bag of `MAX_CONTEXTS` tuples (source_terminal, path, target_terminal).
        # The valid mask indicates, for each context, whether it actually exists or is just padding.
        path_source_token_input = Input((self.config.MAX_CONTEXTS,), dtype=tf.int32)
        path_input = Input((self.config.MAX_CONTEXTS,), dtype=tf.int32)
        path_target_token_input = Input((self.config.MAX_CONTEXTS,), dtype=tf.int32)
        context_valid_mask = Input((self.config.MAX_CONTEXTS,))

        # Input paths are indices; we embed them here.
        paths_embedded = Embedding(
            self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE, name='path_embedding')(path_input)

        # Input terminals are indices; we embed them here (source and target terminals share one embedding layer).
        token_embedding_shared_layer = Embedding(
            self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE, name='token_embedding')
        path_source_token_embedded = token_embedding_shared_layer(path_source_token_input)
        path_target_token_embedded = token_embedding_shared_layer(path_target_token_input)

        # A `context` is the concatenation of the two terminal embeddings and the path embedding.
        # Each context is a vector of size 3 * EMBEDDINGS_SIZE.
        context_embedded = Concatenate()([path_source_token_embedded, paths_embedded, path_target_token_embedded])
        context_embedded = Dropout(1 - self.config.DROPOUT_KEEP_RATE)(context_embedded)

        # Let's get dense: apply a dense layer to each context vector (the same weights for all contexts).
        context_after_dense = TimeDistributed(
            Dense(self.config.CODE_VECTOR_SIZE, use_bias=False, activation='tanh'))(context_embedded)

        # The final code vectors are obtained by applying attention to the "densed" context vectors.
        code_vectors, attention_weights = AttentionLayer(name='attention')(
            [context_after_dense, context_valid_mask])

        # "Decode": use another dense layer to get the target word embedding from each code vector.
        target_index = Dense(
            self.vocabs.target_vocab.size, use_bias=False, activation='softmax', name='target_index')(code_vectors)

        # Wrap the layers into a Keras model, using our subtoken metrics and the CE loss.
        inputs = [path_source_token_input, path_input, path_target_token_input, context_valid_mask]
        self.keras_train_model = keras.Model(inputs=inputs, outputs=target_index)

        # Intermediate-layer model used to extract the code-vector embeddings (the attention layer's output).
        layer_name = 'attention'
        self.int_lev_model = keras.Model(inputs=self.keras_train_model.input,
                                         outputs=self.keras_train_model.get_layer(layer_name).output)

        # Actual target word predictions (as strings). Used as a second output layer.
        # Used for predict() and for the evaluation metrics calculations.
        topk_predicted_words, topk_predicted_words_scores = TopKWordPredictionsLayer(
            self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION,
            self.vocabs.target_vocab.get_index_to_word_lookup_table(),
            name='target_string')(target_index)

        # We use another dedicated Keras model for evaluation.
        # The evaluation model outputs the `topk_predicted_words` as a 2nd output.
        # The separation between train and eval models is for efficiency.
        self.keras_eval_model = keras.Model(
            inputs=inputs, outputs=[target_index, topk_predicted_words], name="code2vec-keras-model")

        # We use another dedicated Keras function to produce predictions.
        # It has additional outputs beyond the original model's.
        # It is based on the trained layers of the original model and uses their weights.
        predict_outputs = tuple(KerasPredictionModelOutput(
            target_index=target_index, code_vectors=code_vectors, attention_weights=attention_weights,
            topk_predicted_words=topk_predicted_words, topk_predicted_words_scores=topk_predicted_words_scores))
        self.keras_model_predict_function = K.function(inputs=inputs, outputs=predict_outputs)

    def _create_metrics_for_keras_eval_model(self) -> Dict[str, List[Union[Callable, keras.metrics.Metric]]]:
        top_k_acc_metrics = []
        for k in range(1, self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION + 1):
            top_k_acc_metric = partial(
                sparse_top_k_categorical_accuracy, k=k)
            top_k_acc_metric.__name__ = 'top{k}_acc'.format(k=k)
            top_k_acc_metrics.append(top_k_acc_metric)
        predicted_words_filters = [
            lambda word_strings: tf.not_equal(word_strings, self.vocabs.target_vocab.special_words.OOV),
            lambda word_strings: tf.strings.regex_full_match(word_strings, r'^[a-zA-Z\|]+$')
        ]
        words_subtokens_metrics = [
            WordsSubtokenPrecisionMetric(predicted_words_filters=predicted_words_filters, name='subtoken_precision'),
            WordsSubtokenRecallMetric(predicted_words_filters=predicted_words_filters, name='subtoken_recall'),
            WordsSubtokenF1Metric(predicted_words_filters=predicted_words_filters, name='subtoken_f1')
        ]
        return {'target_index': top_k_acc_metrics, 'target_string': words_subtokens_metrics}

    @classmethod
    def _create_optimizer(cls):
        return tf.optimizers.Adam()

    def _compile_keras_model(self, optimizer=None):
        if optimizer is None:
            optimizer = self.keras_train_model.optimizer
            if optimizer is None:
                optimizer = self._create_optimizer()

        def zero_loss(true_word, topk_predictions):
            return tf.constant(0.0, shape=(), dtype=tf.float32)

        self.keras_train_model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer)
        self.keras_eval_model.compile(
            loss={'target_index': 'sparse_categorical_crossentropy', 'target_string': zero_loss},
            optimizer=optimizer,
            metrics=self._create_metrics_for_keras_eval_model())

    def _create_data_reader(self, estimator_action: EstimatorAction, repeat_endlessly: bool = False):
        return PathContextReader(
            vocabs=self.vocabs,
            config=self.config,
            model_input_tensors_former=_KerasModelInputTensorsFormer(estimator_action=estimator_action),
            estimator_action=estimator_action,
            repeat_endlessly=repeat_endlessly)

    def _create_train_callbacks(self) -> List[Callback]:
        # TODO: do we want to use early stopping? If so, use the right checkpoint manager and set the correct
        #       `monitor` quantity (example: monitor='val_acc', mode='max').
        keras_callbacks = [
            ModelTrainingStatusTrackerCallback(self.training_status),
            ModelTrainingProgressLoggerCallback(self.config, self.training_status),
        ]
        if self.config.is_saving:
            keras_callbacks.append(ModelCheckpointSaverCallback(
                self, self.config.SAVE_EVERY_EPOCHS, self.logger))
        if self.config.is_testing:
            keras_callbacks.append(ModelEvaluationCallback(self))
        if self.config.USE_TENSORBOARD:
            log_dir = "logs/scalars/train_" + common.now_str()
            tensorboard_callback = keras.callbacks.TensorBoard(
                log_dir=log_dir,
                update_freq=self.config.NUM_BATCHES_TO_LOG_PROGRESS * self.config.TRAIN_BATCH_SIZE)
            keras_callbacks.append(tensorboard_callback)
        return keras_callbacks

    def train(self):
        # Initialize the input pipeline reader.
        train_data_input_reader = self._create_data_reader(estimator_action=EstimatorAction.Train)
        training_history = self.keras_train_model.fit(
            train_data_input_reader.get_dataset(),
            steps_per_epoch=self.config.train_steps_per_epoch,
            epochs=self.config.NUM_TRAIN_EPOCHS,
            initial_epoch=self.training_status.nr_epochs_trained,
            verbose=self.config.VERBOSE_MODE,
            callbacks=self._create_train_callbacks())
        self.log(training_history)

    def evaluate(self) -> Optional[ModelEvaluationResults]:
        val_data_input_reader = self._create_data_reader(estimator_action=EstimatorAction.Evaluate)
        eval_res = self.keras_eval_model.evaluate(
            val_data_input_reader.get_dataset(),
            steps=self.config.test_steps,
            verbose=self.config.VERBOSE_MODE)
        k = self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION
        with open('log.txt', 'w') as log_output_file:
            log_output_file.write(str(eval_res) + '\n')
        return ModelEvaluationResults(
            topk_acc=eval_res[3:k+3],
            subtoken_precision=eval_res[k+3],
            subtoken_recall=eval_res[k+4],
            subtoken_f1=eval_res[k+5],
            loss=eval_res[1]
        )

    def predict(self, predict_data_rows: Iterable[str]) -> List[ModelPredictionResults]:
        predict_input_reader = self._create_data_reader(estimator_action=EstimatorAction.Predict)
        input_iterator = predict_input_reader.process_and_iterate_input_from_data_lines(predict_data_rows)
        all_model_prediction_results = []
        for input_row in input_iterator:
            # Perform the actual prediction and get raw results.
            input_for_predict = input_row[0][:4]  # take only the relevant input vectors (without the targets).
            prediction_results = self.keras_model_predict_function(input_for_predict)
            # Extract the code vector (attention-layer output) and dump it to a file.
            embedding = self.int_lev_model.predict(input_for_predict)
            print(np.array(embedding[0][0]), file=open('cd2vec/EMBEDDINGS.txt', 'w'))

            # Make `input_row` and `prediction_results` easy to read (by accessing named fields).
            prediction_results = KerasPredictionModelOutput(
                *common.squeeze_single_batch_dimension_for_np_arrays(prediction_results))
            input_row = _KerasModelInputTensorsFormer(
                estimator_action=EstimatorAction.Predict).from_model_input_form(input_row)
            input_row = ReaderInputTensors(*common.squeeze_single_batch_dimension_for_np_arrays(input_row))

            # Calculate the attention weight for each context.
            attention_per_context = self._get_attention_weight_per_context(
                path_source_strings=input_row.path_source_token_strings,
                path_strings=input_row.path_strings,
                path_target_strings=input_row.path_target_token_strings,
                attention_weights=prediction_results.attention_weights
            )

            # Store the calculated prediction results in the wanted format.
            model_prediction_results = ModelPredictionResults(
                original_name=common.binary_to_string(input_row.target_string.item()),
                topk_predicted_words=common.binary_to_string_list(prediction_results.topk_predicted_words),
                topk_predicted_words_scores=prediction_results.topk_predicted_words_scores,
                attention_per_context=attention_per_context,
                code_vector=prediction_results.code_vectors)
            all_model_prediction_results.append(model_prediction_results)
        return all_model_prediction_results

    def _save_inner_model(self, path):
        if self.config.RELEASE:
            self.keras_train_model.save_weights(self.config.get_model_weights_path(path))
        else:
            self._get_checkpoint_manager().save(checkpoint_number=self.training_status.nr_epochs_trained)

    def _create_inner_model(self):
        self._create_keras_model()
        self._compile_keras_model()
        self.keras_train_model.summary(print_fn=self.log)

    def _load_inner_model(self):
        self._create_keras_model()
        self._compile_keras_model()

        # When loading the model for further training, we must use the full saved model file (not just weights).
        # We load the entire model if we have to, or if there is no model-weights file to load.
        must_use_entire_model = self.config.is_training
        entire_model_exists = os.path.exists(self.config.entire_model_load_path)
        model_weights_exist = os.path.exists(self.config.model_weights_load_path)
        use_full_model = must_use_entire_model or not model_weights_exist

        if must_use_entire_model and not entire_model_exists:
            raise ValueError(
                "There is no model at path `{model_file_path}`. When loading the model for further training, "
                "we must use an entire saved model file (not just weights).".format(
                    model_file_path=self.config.entire_model_load_path))
        if not entire_model_exists and not model_weights_exist:
            raise ValueError(
                "There is no entire model to load at path `{entire_model_path}`, "
                "and there is no model weights file to load at path `{model_weights_path}`.".format(
                    entire_model_path=self.config.entire_model_load_path,
                    model_weights_path=self.config.model_weights_load_path))

        if use_full_model:
            self.log('Loading entire model from path `{}`.'.format(self.config.entire_model_load_path))
            latest_checkpoint = tf.train.latest_checkpoint(self.config.entire_model_load_path)
            if latest_checkpoint is None:
                raise ValueError("Failed to load model: Model latest checkpoint is not found.")
            self.log('Loading latest checkpoint `{}`.'.format(latest_checkpoint))
            status = self._get_checkpoint().restore(latest_checkpoint)
            status.initialize_or_restore()
            # FIXME: are we sure we have to re-compile here? I turned it off to preserve the optimizer state.
            # self._compile_keras_model()  # We have to re-compile because we also recovered the `tf.train.AdamOptimizer`.
            self.training_status.nr_epochs_trained = int(latest_checkpoint.split('-')[-1])
        else:
            # Load the "released" model (only the weights).
            self.log('Loading model weights from path `{}`.'.format(self.config.model_weights_load_path))
            self.keras_train_model.load_weights(self.config.model_weights_load_path)
        self.keras_train_model.summary(print_fn=self.log)

    def _get_checkpoint(self):
        assert self.keras_train_model is not None and self.keras_train_model.optimizer is not None
        if self._checkpoint is None:
            # TODO: we would like to save (& restore) the `nr_epochs_trained` as well.
            self._checkpoint = tf.train.Checkpoint(
                # nr_epochs_trained=tf.Variable(self.training_status.nr_epochs_trained, name='nr_epochs_trained'),
                optimizer=self.keras_train_model.optimizer, model=self.keras_train_model)
        return self._checkpoint

    def _get_checkpoint_manager(self):
        if self._checkpoint_manager is None:
            self._checkpoint_manager = tf.train.CheckpointManager(
                self._get_checkpoint(), self.config.entire_model_save_path,
                max_to_keep=self.config.MAX_TO_KEEP)
        return self._checkpoint_manager

    # Leftover sketch for extracting an intermediate layer's output (kept commented-out for reference):
    # def _get_layer(self, layer_name):
    #     res = self.int_lev_model.predict()
    #     my_get_layer = K.function([self.keras_train_model.layers[0].input, K.learning_phase()],
    #                               [self.keras_train_model.layers[9].output])
    #     weight = K.print_tensor(self.keras_train_model.get_layer(layer_name).output[0])
    #     layer_output = my_get_layer([x, 1])[0]
    #     return layer_output

    def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
        assert vocab_type in VocabType
        vocab_type_to_embedding_layer_mapping = {
            VocabType.Target: 'target_index',
            VocabType.Token: 'token_embedding',
            VocabType.Path: 'path_embedding'
        }
        embedding_layer_name = vocab_type_to_embedding_layer_mapping[vocab_type]
        weight = np.array(self.keras_train_model.get_layer(embedding_layer_name).get_weights()[0])
        assert len(weight.shape) == 2

        # `token` and `path` have actual `Embedding` layers, but `target` has just a `Dense` layer.
        # Hence, transpose the weight matrix when necessary.
        assert self.vocabs.get(vocab_type).size in weight.shape
        if self.vocabs.get(vocab_type).size != weight.shape[0]:
            weight = np.transpose(weight)

        return weight

    def _create_lookup_tables(self):
        PathContextReader.create_needed_vocabs_lookup_tables(self.vocabs)
        self.log('Lookup tables created.')

    def _initialize(self):
        self._create_lookup_tables()


class ModelEvaluationCallback(MultiBatchCallback):
    """
    This callback is passed to the `model.fit()` call.
    It is responsible for triggering model evaluation during training.
    The reason we use a callback, rather than passing validation data to `model.fit()`, is that:
        (i) the training model differs from the evaluation model for efficiency considerations;
        (ii) we want to control the logging format;
        (iii) we want the evaluation to occur once per 1K batches (rather than only once per epoch).
    """

    def __init__(self, code2vec_model: 'Code2VecModel'):
        self.code2vec_model = code2vec_model
        self.avg_eval_duration: Optional[int] = None
        super(ModelEvaluationCallback, self).__init__(self.code2vec_model.config.NUM_TRAIN_BATCHES_TO_EVALUATE)

    def on_epoch_end(self, epoch, logs=None):
        self.perform_evaluation()

    def on_multi_batch_end(self, batch, logs, multi_batch_elapsed):
        self.perform_evaluation()

    def perform_evaluation(self):
        if self.avg_eval_duration is None:
            self.code2vec_model.log('Evaluating...')
        else:
            self.code2vec_model.log('Evaluating... (takes ~{})'.format(
                str(datetime.timedelta(seconds=int(self.avg_eval_duration)))))
        eval_start_time = time.time()
        evaluation_results = self.code2vec_model.evaluate()
        eval_duration = time.time() - eval_start_time
        if self.avg_eval_duration is None:
            self.avg_eval_duration = eval_duration
        else:
            self.avg_eval_duration = eval_duration * 0.5 + self.avg_eval_duration * 0.5
        self.code2vec_model.log('Done evaluating (took {}). Evaluation results:'.format(
            str(datetime.timedelta(seconds=int(eval_duration)))))

        self.code2vec_model.log(
            '    loss: {loss:.4f}, f1: {f1:.4f}, recall: {recall:.4f}, precision: {precision:.4f}'.format(
                loss=evaluation_results.loss, f1=evaluation_results.subtoken_f1,
                recall=evaluation_results.subtoken_recall, precision=evaluation_results.subtoken_precision))
        top_k_acc_formatted = ['top{}: {:.4f}'.format(i, acc)
                               for i, acc in enumerate(evaluation_results.topk_acc, start=1)]
        for top_k_acc_chunk in common.chunks(top_k_acc_formatted, 5):
            self.code2vec_model.log('    ' + (', '.join(top_k_acc_chunk)))


class _KerasModelInputTensorsFormer(ModelInputTensorsFormer):
    """
    An instance of this class is passed to the reader in order to help the reader construct the input
    in the form that the model expects to receive it.
    This class also enables conveniently & clearly accessing input parts by their field names,
        e.g. `tensors.path_indices` instead of `tensors[1]`.
    This allows the input tensors to be passed as pure tuples along the computation graph, while the
    Python functions that construct the graph can easily (and clearly) access tensors by name.
    """

    def __init__(self, estimator_action: EstimatorAction):
        self.estimator_action = estimator_action

    def to_model_input_form(self, input_tensors: ReaderInputTensors):
        inputs = (input_tensors.path_source_token_indices, input_tensors.path_indices,
                  input_tensors.path_target_token_indices, input_tensors.context_valid_mask)
        if self.estimator_action.is_train:
            targets = input_tensors.target_index
        else:
            targets = {'target_index': input_tensors.target_index, 'target_string': input_tensors.target_string}
        if self.estimator_action.is_predict:
            inputs += (input_tensors.path_source_token_strings, input_tensors.path_strings,
                       input_tensors.path_target_token_strings)
        return inputs, targets

    def from_model_input_form(self, input_row) -> ReaderInputTensors:
        inputs, targets = input_row
        return ReaderInputTensors(
            path_source_token_indices=inputs[0],
            path_indices=inputs[1],
            path_target_token_indices=inputs[2],
            context_valid_mask=inputs[3],
            target_index=targets if self.estimator_action.is_train else targets['target_index'],
            target_string=targets['target_string'] if not self.estimator_action.is_train else None,
            path_source_token_strings=inputs[4] if self.estimator_action.is_predict else None,
            path_strings=inputs[5] if self.estimator_action.is_predict else None,
            path_target_token_strings=inputs[6] if self.estimator_action.is_predict else None
        )


"""Used for convenient-and-clear access to raw prediction result parts (by their names)."""
KerasPredictionModelOutput = namedtuple(
    'KerasModelOutput', ['target_index', 'code_vectors', 'attention_weights',
                         'topk_predicted_words', 'topk_predicted_words_scores'])
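

# --- Usage sketch (added for illustration; not part of the original file) ---
# A minimal, hypothetical example of how this model class is typically driven end-to-end.
# The `Config` constructor arguments below are an assumption about the surrounding project's
# configuration object and may need adjusting; treat this as a sketch, not a verified entry point.
if __name__ == '__main__':
    config = Config(set_defaults=True, load_from_args=True, verify=True)  # assumed constructor signature
    model = Code2VecModel(config)  # builds (or loads) the Keras train/eval models
    if config.is_training:
        model.train()  # runs `fit` with the callbacks created in `_create_train_callbacks`
    if config.is_testing:
        print(model.evaluate())  # returns a `ModelEvaluationResults` namedtuple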