Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

interactive_predict.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
  1. import traceback
  2. from keras.models import Model
  3. from common import common
  4. from extractor import Extractor
  5. SHOW_TOP_CONTEXTS = 10
  6. MAX_PATH_LENGTH = 8
  7. MAX_PATH_WIDTH = 1
  8. JAR_PATH = 'cd2vec/cli.jar'
  9. class InteractivePredictor:
  10. exit_keywords = ['exit', 'quit', 'q']
  11. def __init__(self, config, model):
  12. # model.predict([])
  13. self.model = model
  14. self.config = config
  15. self.path_extractor = Extractor(config,
  16. jar_path=JAR_PATH,
  17. max_path_length=MAX_PATH_LENGTH,
  18. max_path_width=MAX_PATH_WIDTH)
  19. def read_file(self, input_filename):
  20. with open(input_filename, 'r') as file:
  21. return file.readlines()
  22. def predict(self):
  23. input_filename = 'pred_files/Input.py'
  24. print('Starting interactive prediction...')
  25. while True:
  26. print(
  27. 'Modify the file: "%s" and press any key when ready, or "q" / "quit" / "exit" to exit' % input_filename)
  28. user_input = input()
  29. if user_input.lower() in self.exit_keywords:
  30. print('Exiting...')
  31. return
  32. try:
  33. predict_lines, hash_to_string_dict = self.path_extractor.extract_paths(input_filename)
  34. except ValueError as e:
  35. print(e)
  36. raw_prediction_results = self.model.predict(predict_lines)
  37. method_prediction_results = common.parse_prediction_results(
  38. raw_prediction_results, hash_to_string_dict,
  39. self.model.vocabs.target_vocab.special_words, topk=SHOW_TOP_CONTEXTS)
  40. for raw_prediction, method_prediction in zip(raw_prediction_results, method_prediction_results):
  41. print('Original name:\t' + method_prediction.original_name)
  42. for name_prob_pair in method_prediction.predictions:
  43. print('\t(%f) predicted: %s' % (name_prob_pair['probability'], name_prob_pair['name']))
  44. print('Attention:')
  45. for attention_obj in method_prediction.attention_paths:
  46. print('%f\tcontext: %s,%s,%s' % (
  47. attention_obj['score'], attention_obj['token1'], attention_obj['path'], attention_obj['token2']))
  48. if self.config.EXPORT_CODE_VECTORS:
  49. print('Code vector:')
  50. print(' '.join(map(str, raw_prediction.code_vector)))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...