Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

interactive.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
  1. #!/usr/bin/env python3 -u
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the license found in the LICENSE file in
  6. # the root directory of this source tree. An additional grant of patent rights
  7. # can be found in the PATENTS file in the same directory.
  8. import sys
  9. import torch
  10. from torch.autograd import Variable
  11. from fairseq import options, tokenizer, utils
  12. from fairseq.sequence_generator import SequenceGenerator
  13. def main(args):
  14. print(args)
  15. assert not args.sampling or args.nbest == args.beam, \
  16. '--sampling requires --nbest to be equal to --beam'
  17. assert not args.max_sentences, \
  18. '--max-sentences/--batch-size is not supported in interactive mode'
  19. use_cuda = torch.cuda.is_available() and not args.cpu
  20. # Load ensemble
  21. print('| loading model(s) from {}'.format(', '.join(args.path)))
  22. models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
  23. src_dict, dst_dict = models[0].src_dict, models[0].dst_dict
  24. print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
  25. print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))
  26. # Optimize ensemble for generation
  27. for model in models:
  28. model.make_generation_fast_(
  29. beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
  30. )
  31. # Initialize generator
  32. translator = SequenceGenerator(
  33. models, beam_size=args.beam, stop_early=(not args.no_early_stop),
  34. normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
  35. unk_penalty=args.unkpen)
  36. if use_cuda:
  37. translator.cuda()
  38. # Load alignment dictionary for unknown word replacement
  39. # (None if no unknown word replacement, empty if no path to align dictionary)
  40. align_dict = utils.load_align_dict(args.replace_unk)
  41. print('| Type the input sentence and press return:')
  42. for src_str in sys.stdin:
  43. src_str = src_str.strip()
  44. src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
  45. if use_cuda:
  46. src_tokens = src_tokens.cuda()
  47. src_lengths = src_tokens.new([src_tokens.numel()])
  48. translations = translator.generate(
  49. Variable(src_tokens.view(1, -1)),
  50. Variable(src_lengths.view(-1)),
  51. )
  52. hypos = translations[0]
  53. print('O\t{}'.format(src_str))
  54. # Process top predictions
  55. for hypo in hypos[:min(len(hypos), args.nbest)]:
  56. hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
  57. hypo_tokens=hypo['tokens'].int().cpu(),
  58. src_str=src_str,
  59. alignment=hypo['alignment'].int().cpu(),
  60. align_dict=align_dict,
  61. dst_dict=dst_dict,
  62. remove_bpe=args.remove_bpe,
  63. )
  64. print('H\t{}\t{}'.format(hypo['score'], hypo_str))
  65. print('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
  66. if __name__ == '__main__':
  67. parser = options.get_generation_parser()
  68. args = parser.parse_args()
  69. main(args)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...