sequence_scorer.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

from fairseq import utils


class SequenceScorer(object):
    """Scores the target for a given source sentence."""

    def __init__(self, models, tgt_dict):
        self.models = models
        self.pad = tgt_dict.pad()

    def cuda(self):
        for model in self.models:
            model.cuda()
        return self

    def score_batched_itr(self, data_itr, cuda=False, timer=None):
        """Iterate over a batched dataset and yield scored translations."""
        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if timer is not None:
                timer.start()
            pos_scores, attn = self.score(s)
            for i, id in enumerate(s['id'].data):
                # remove padding from ref
                src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                tgt_len = ref.numel()
                pos_scores_i = pos_scores[i][:tgt_len]
                score_i = pos_scores_i.sum() / tgt_len
                if attn is not None:
                    attn_i = attn[i]
                    _, alignment = attn_i.max(dim=0)
                else:
                    attn_i = alignment = None
                hypos = [{
                    'tokens': ref,
                    'score': score_i,
                    'attention': attn_i,
                    'alignment': alignment,
                    'positional_scores': pos_scores_i,
                }]
                if timer is not None:
                    timer.stop(s['ntokens'])
                # return results in the same format as SequenceGenerator
                yield id, src, ref, hypos

    def score(self, sample):
        """Score a batch of translations."""
        net_input = sample['net_input']

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in self.models:
            with torch.no_grad():
                model.eval()
                decoder_out = model.forward(**net_input)
                attn = decoder_out[1]
            probs = model.get_normalized_probs(decoder_out, log_probs=len(self.models) == 1, sample=sample).data
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None and torch.is_tensor(attn):
                attn = attn.data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        if len(self.models) > 1:
            avg_probs.div_(len(self.models))
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(len(self.models))

        avg_probs = avg_probs.gather(
            dim=2,
            index=sample['target'].data.unsqueeze(-1),
        )
        return avg_probs.squeeze(2), avg_attn
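
The usage sketch below is not part of the original file. It is a minimal, hypothetical illustration of the contract SequenceScorer.score() relies on: each model's forward() must accept the keys of sample['net_input'] and return a (logits, attn) pair, the model must expose get_normalized_probs(), and the target dictionary must expose pad(). ToyDict, ToyModel, and the padding index 0 are invented for this sketch; real fairseq models and dictionaries satisfy the same interface.

if __name__ == '__main__':
    import torch.nn.functional as F

    class ToyDict(object):
        """Hypothetical stand-in for a fairseq dictionary (sketch only)."""
        def pad(self):
            return 0  # assumed padding index for this sketch

    class ToyModel(torch.nn.Module):
        """Hypothetical stand-in for a fairseq model: embeds the previous
        output tokens and predicts a per-position vocabulary distribution,
        returning (logits, attn) like a fairseq decoder."""
        def __init__(self, vocab_size=10, dim=8):
            super(ToyModel, self).__init__()
            self.embed = torch.nn.Embedding(vocab_size, dim)
            self.proj = torch.nn.Linear(dim, vocab_size)

        def forward(self, src_tokens, prev_output_tokens):
            logits = self.proj(self.embed(prev_output_tokens))
            return logits, None  # this toy model produces no attention

        def get_normalized_probs(self, decoder_out, log_probs, sample=None):
            logits = decoder_out[0]
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            return F.softmax(logits, dim=-1)

    scorer = SequenceScorer([ToyModel()], ToyDict())
    sample = {
        'net_input': {
            'src_tokens': torch.tensor([[4, 5, 6]]),
            'prev_output_tokens': torch.tensor([[1, 7, 8]]),
        },
        'target': torch.tensor([[7, 8, 2]]),
    }
    pos_scores, attn = scorer.score(sample)
    # pos_scores[i, j] is the log-probability assigned to target token j of
    # sentence i; score_batched_itr() sums these and divides by the target
    # length to produce the per-sentence score.
    print(pos_scores)

One design choice worth noting in score(): with a single model, get_normalized_probs() is asked for log-probabilities directly, but with an ensemble the raw probabilities are summed in probability space, divided by the ensemble size, and only then logged, so the result is the log of the averaged distribution rather than an average of log-probabilities.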