Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

interactive.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
  1. #!/usr/bin/env python3
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the license found in the LICENSE file in
  6. # the root directory of this source tree. An additional grant of patent rights
  7. # can be found in the PATENTS file in the same directory.
  8. #
  9. import sys
  10. import torch
  11. from torch.autograd import Variable
  12. from fairseq import options, tokenizer, utils
  13. from fairseq.sequence_generator import SequenceGenerator
  14. def main():
  15. parser = options.get_parser('Generation')
  16. parser.add_argument('--path', metavar='FILE', required=True, action='append',
  17. help='path(s) to model file(s)')
  18. options.add_dataset_args(parser)
  19. options.add_generation_args(parser)
  20. args = parser.parse_args()
  21. print(args)
  22. use_cuda = torch.cuda.is_available() and not args.cpu
  23. # Load ensemble
  24. print('| loading model(s) from {}'.format(', '.join(args.path)))
  25. models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
  26. src_dict, dst_dict = models[0].src_dict, models[0].dst_dict
  27. print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
  28. print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))
  29. # Optimize ensemble for generation
  30. for model in models:
  31. model.make_generation_fast_(
  32. beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
  33. # Initialize generator
  34. translator = SequenceGenerator(
  35. models, beam_size=args.beam, stop_early=(not args.no_early_stop),
  36. normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
  37. unk_penalty=args.unkpen)
  38. if use_cuda:
  39. translator.cuda()
  40. # Load alignment dictionary for unknown word replacement
  41. # (None if no unknown word replacement, empty if no path to align dictionary)
  42. align_dict = utils.load_align_dict(args.replace_unk)
  43. print('| Type the input sentence and press return:')
  44. for src_str in sys.stdin:
  45. src_str = src_str.strip()
  46. src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
  47. if use_cuda:
  48. src_tokens = src_tokens.cuda()
  49. translations = translator.generate(Variable(src_tokens.view(1, -1)))
  50. hypos = translations[0]
  51. print('O\t{}'.format(src_str))
  52. # Process top predictions
  53. for hypo in hypos[:min(len(hypos), args.nbest)]:
  54. hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
  55. hypo_tokens=hypo['tokens'].int().cpu(),
  56. src_str=src_str,
  57. alignment=hypo['alignment'].int().cpu(),
  58. align_dict=align_dict,
  59. dst_dict=dst_dict,
  60. remove_bpe=args.remove_bpe)
  61. print('H\t{}\t{}'.format(hypo['score'], hypo_str))
  62. print('A\t{}'.format(' '.join(map(str, alignment))))
  63. if __name__ == '__main__':
  64. main()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...