1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
|
- # Copyright (c) 2017-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the LICENSE file in
- # the root directory of this source tree. An additional grant of patent rights
- # can be found in the PATENTS file in the same directory.
- import math
- import torch
- from fairseq import search, utils
- from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
    """Generate target sequences for batches of source sentences.

    Decoding is driven by a pluggable search strategy (beam search, diverse
    beam search, length-constrained beam search or sampling) and may average
    the predictions of several models (an ensemble).
    """

    def __init__(
        self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
        normalize_scores=True, len_penalty=1., unk_penalty=0., retain_dropout=False,
        sampling=False, sampling_topk=-1, sampling_temperature=1.,
        diverse_beam_groups=-1, diverse_beam_strength=0.5,
        match_source_len=False, no_repeat_ngram_size=0
    ):
        """Generates translations of a given source sentence.

        Args:
            models (list): models to decode with; each must provide
                ``max_decoder_positions()``, ``encoder`` and ``decoder``.
                More than one model is decoded as an ensemble.
            tgt_dict: target dictionary providing ``pad()``, ``unk()``,
                ``eos()`` and ``len()``
            beam_size (int, optional): beam width (default: 1)
            min/maxlen (int, optional): the length of the generated output will
                be bounded by minlen and maxlen (not including end-of-sentence).
                maxlen is additionally capped by the smallest
                ``max_decoder_positions() - 1`` over all models.
            stop_early (bool, optional): stop generation immediately after we
                finalize beam_size hypotheses, even though longer hypotheses
                might have better normalized scores (default: True)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            retain_dropout (bool, optional): use dropout when generating
                (default: False)
            sampling (bool, optional): sample outputs instead of beam search
                (default: False)
            sampling_topk (int, optional): only sample among the top-k choices
                at each step; requires ``sampling`` (default: -1)
            sampling_temperature (float, optional): temperature for sampling,
                where values >1.0 produce more uniform sampling and values
                <1.0 produce sharper sampling (default: 1.0)
            diverse_beam_groups/strength (int/float, optional): parameters for
                Diverse Beam Search; enabled when groups > 0
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
            no_repeat_ngram_size (int, optional): when > 0, an n-gram of this
                size may not occur twice in a hypothesis (default: 0)
        """
        self.models = models
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.beam_size = beam_size
        self.minlen = minlen
        max_decoder_len = min(m.max_decoder_positions() for m in self.models)
        max_decoder_len -= 1  # we define maxlen not including the EOS marker
        self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.retain_dropout = retain_dropout
        self.match_source_len = match_source_len
        self.no_repeat_ngram_size = no_repeat_ngram_size

        assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'

        # pick the search strategy; the first matching option wins
        if sampling:
            self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
        elif diverse_beam_groups > 0:
            self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
        elif match_source_len:
            self.search = search.LengthConstrainedBeamSearch(
                tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
            )
        else:
            self.search = search.BeamSearch(tgt_dict)

    def cuda(self):
        """Call ``cuda()`` on every model and return *self* for chaining."""
        for model in self.models:
            model.cuda()
        return self

    def generate_batched_itr(
        self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
        cuda=False, timer=None, prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            data_itr (iterable): batches (dicts). Batches without a
                ``'net_input'`` entry are skipped. Each remaining batch must
                provide ``'id'`` and ``'net_input'['src_tokens']``;
                ``'target'`` may be missing or None.
            beam_size (int, optional): overrides the beam size of this
                generator (default: *self.beam_size*)
            maxlen_a/b (float/int, optional): generate sequences of maximum
                length ``ax + b``, where ``x`` is the source sentence length.
                ``maxlen_b`` defaults to *self.maxlen*.
            cuda (bool, optional): use GPU for generation
            timer (StopwatchMeter, optional): time generations
            prefix_size (int, optional): prefill the generation with the gold
                prefix up to this length (requires a target).

        Yields:
            tuple: ``(id, src, ref, hypos)`` per sentence, where *src* and
            *ref* have padding stripped (*ref* is None without a target) and
            *hypos* is the list of hypotheses for that sentence.
        """
        if maxlen_b is None:
            maxlen_b = self.maxlen

        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if 'net_input' not in s:
                continue
            net_input = s['net_input']
            # a batch may carry no reference at all (missing key) or an explicit None
            target = s['target'] if 'target' in s else None

            # model.forward normally channels prev_output_tokens into the decoder
            # separately, but SequenceGenerator directly calls model.encoder
            encoder_input = {
                k: v for k, v in net_input.items()
                if k != 'prev_output_tokens'
            }
            srclen = encoder_input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    encoder_input,
                    beam_size=beam_size,
                    maxlen=int(maxlen_a*srclen + maxlen_b),
                    prefix_tokens=target[:, :prefix_size] if prefix_size > 0 else None,
                )
            if timer is not None:
                # the timer counts the tokens of the best hypothesis of each sentence
                timer.stop(sum(len(h[0]['tokens']) for h in hypos))
            for i, sample_id in enumerate(s['id'].data):
                # remove padding
                src = utils.strip_pad(net_input['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(target.data[i, :], self.pad) if target is not None else None
                yield sample_id, src, ref, hypos[i]

    def generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
        """Generate a batch of translations.

        Args:
            encoder_input (dict): dictionary containing the inputs to
                *model.encoder.forward*.
            beam_size (int, optional): overriding the beam size
                (default: *self.beam_size*).
            maxlen (int, optional): maximum length of the generated sequence
                (default: *self.maxlen*, which also caps larger values)
            prefix_tokens (LongTensor, optional): force decoder to begin with
                these tokens

        Returns:
            list: one entry per input sentence, each a list of hypothesis
            dicts (keys ``tokens``, ``score``, ``attention``, ``alignment``,
            ``positional_scores``) sorted by descending score.
        """
        with torch.no_grad():
            return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)

    def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
        """Run the search loop; see :meth:`generate` for arguments and result.

        The encoder of every model is run once and its output replicated
        ``beam_size`` times per sentence. Each step then scores the next
        token, lets ``self.search`` (or the forced prefix) propose
        ``2 * beam_size`` candidates per sentence, finalizes candidates ending
        in EOS, removes finished sentences from the batch and carries the best
        ``beam_size`` unfinished candidates into the next step.
        """
        src_tokens = encoder_input['src_tokens']
        # source lengths exclude both EOS and padding symbols
        src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
        if self.match_source_len:
            maxlen = src_lengths.max().item()

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        incremental_states = {}
        for model in self.models:
            if not self.retain_dropout:
                model.eval()
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                incremental_states[model] = {}
            else:
                incremental_states[model] = None

            # compute the encoder output for each beam
            encoder_out = model.encoder(**encoder_input)
            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
            new_order = new_order.to(src_tokens.device).long()
            encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn, attn_buf = None, None
        nonpad_idxs = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen ** self.len_penalty
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses

            Returns:
                list of batch indices (relative to the current, shrunken
                batch) of the sentences that became finished at this step
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            # cum_unfin[k] = number of finished sentences preceding the k-th
            # unfinished one; maps current batch indices to original ones
            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                if self.match_source_len and step > src_lengths[unfin_idx]:
                    score = -math.inf

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                        _, alignment = hypo_attn.max(dim=0)
                    else:
                        hypo_attn = None
                        alignment = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            lprobs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                                gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
                    attn_buf = attn.clone()
                    nonpad_idxs = src_tokens.ne(self.pad)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
                self.search.set_src_lengths(src_lengths)

                if self.no_repeat_ngram_size > 0:
                    def calculate_banned_tokens(bbsz_idx):
                        # before decoding the next token, prevent decoding of ngrams that have already appeared
                        ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
                        return gen_ngrams[bbsz_idx].get(ngram_index, [])

                    if step + 2 - self.no_repeat_ngram_size >= 0:
                        # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
                        banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
                    else:
                        banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                    for bbsz_idx in range(bsz * beam_size):
                        lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = float('-Inf')

                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    # still inside the forced prefix: every candidate of a
                    # sentence is the prefix token, scored from the first beam
                    probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice, dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data
                    ).expand(-1, cand_size)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
                    cand_beams = torch.zeros_like(cand_indices)
                else:
                    cand_scores, cand_indices, cand_beams = self.search.step(
                        step,
                        lprobs.view(bsz, -1, self.vocab_size),
                        scores.view(bsz, beam_size, -1)[:, :, :step],
                    )
            else:
                # make probs contain cumulative scores for each hypothesis
                lprobs.add_(scores[:, step - 1].unsqueeze(-1))

                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    lprobs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= len(finalize_hypos(step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)

            finalized_sents = set()
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores, cand_scores)
                    num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            if len(finalized_sents) > 0:
                # shrink every per-sentence tensor to the unfinished sentences
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(_ignore, active_hypos)
            )

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            # NOTE: this call also writes the gathered scores into scores[:, step] via out=
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized

    def _decode(self, tokens, encoder_outs, incremental_states):
        """Score the next token, averaging over all models.

        Args:
            tokens: target-side tokens generated so far
            encoder_outs (list): encoder output per model, aligned with
                *self.models*
            incremental_states (dict): per-model incremental decoder state
                (None for models decoded without incremental state)

        Returns:
            tuple: ``(log_probs, attn)``. With a single model this is the
            result of :meth:`_decode_one` unchanged. With several models the
            probabilities are averaged in probability space (log of the mean)
            and the attention is the sum of the available attentions divided
            by the number of models; *attn* is None if no model returned one.
        """
        if len(self.models) == 1:
            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)

        log_probs = []
        attn_sum = None
        for model, encoder_out in zip(self.models, encoder_outs):
            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=True)
            log_probs.append(probs)
            if attn is not None:
                # out-of-place add: never modify a tensor handed out by a model
                attn_sum = attn if attn_sum is None else attn_sum + attn
        # log(mean_i p_i) = logsumexp_i(log p_i) - log(n)
        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(self.models))
        avg_attn = attn_sum / len(self.models) if attn_sum is not None else None
        return avg_probs, avg_attn

    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
        """Run one decoder step of *model* and keep only the last position.

        Args:
            tokens: target-side tokens generated so far
            model: model providing ``decoder`` and ``get_normalized_probs``
            encoder_out: encoder output of *model*
            incremental_states (dict): maps *model* to its incremental state,
                or to None to call the decoder without one
            log_probs (bool): forwarded to ``model.get_normalized_probs``

        Returns:
            tuple: ``(probs, attn)`` where *probs* is the normalized output
            for the last position and *attn* the attention for the last
            position, or None if the decoder returned no attention.
        """
        with torch.no_grad():
            if incremental_states[model] is not None:
                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model]))
            else:
                decoder_out = list(model.decoder(tokens, encoder_out))
            decoder_out[0] = decoder_out[0][:, -1, :]
            attn = decoder_out[1]
            # some decoders wrap the attention in a dict under the 'attn' key
            if isinstance(attn, dict):
                attn = attn.get('attn')
            if attn is not None:
                attn = attn[:, -1, :]
        probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
        return probs, attn
|