Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

score.py 2.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
  1. #!/usr/bin/env python3
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the license found in the LICENSE file in
  6. # the root directory of this source tree. An additional grant of patent rights
  7. # can be found in the PATENTS file in the same directory.
  8. """
  9. BLEU scoring of generated translations against reference translations.
  10. """
  11. import argparse
  12. import os
  13. import sys
  14. from fairseq import bleu, tokenizer
  15. from fairseq.data import dictionary
  16. def get_parser():
  17. parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
  18. # fmt: off
  19. parser.add_argument('-s', '--sys', default='-', help='system output')
  20. parser.add_argument('-r', '--ref', required=True, help='references')
  21. parser.add_argument('-o', '--order', default=4, metavar='N',
  22. type=int, help='consider ngrams up to this order')
  23. parser.add_argument('--ignore-case', action='store_true',
  24. help='case-insensitive scoring')
  25. parser.add_argument('--sacrebleu', action='store_true',
  26. help='score with sacrebleu')
  27. # fmt: on
  28. return parser
  29. def main():
  30. parser = get_parser()
  31. args = parser.parse_args()
  32. print(args)
  33. assert args.sys == '-' or os.path.exists(args.sys), \
  34. "System output file {} does not exist".format(args.sys)
  35. assert os.path.exists(args.ref), \
  36. "Reference file {} does not exist".format(args.ref)
  37. dict = dictionary.Dictionary()
  38. def readlines(fd):
  39. for line in fd.readlines():
  40. if args.ignore_case:
  41. yield line.lower()
  42. else:
  43. yield line
  44. if args.sacrebleu:
  45. import sacrebleu
  46. def score(fdsys):
  47. with open(args.ref) as fdref:
  48. print(sacrebleu.corpus_bleu(fdsys, [fdref]))
  49. else:
  50. def score(fdsys):
  51. with open(args.ref) as fdref:
  52. scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
  53. for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
  54. sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
  55. ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
  56. scorer.add(ref_tok, sys_tok)
  57. print(scorer.result_string(args.order))
  58. if args.sys == '-':
  59. score(sys.stdin)
  60. else:
  61. with open(args.sys, 'r') as f:
  62. score(f)
  63. if __name__ == '__main__':
  64. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...