sequence_generator.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch

from fairseq import search, utils
from fairseq.models import FairseqIncrementalDecoder


class SequenceGenerator(object):
    def __init__(
        self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
        normalize_scores=True, len_penalty=1., unk_penalty=0., retain_dropout=False,
        sampling=False, sampling_topk=-1, sampling_temperature=1.,
        diverse_beam_groups=-1, diverse_beam_strength=0.5,
        match_source_len=False, no_repeat_ngram_size=0,
    ):
        """Generates translations of a given source sentence.

        Args:
            beam_size (int, optional): beam width (default: 1)
            min/maxlen (int, optional): the length of the generated output will
                be bounded by minlen and maxlen (not including end-of-sentence)
            stop_early (bool, optional): stop generation immediately after we
                finalize beam_size hypotheses, even though longer hypotheses
                might have better normalized scores (default: True)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            retain_dropout (bool, optional): use dropout when generating
                (default: False)
            sampling (bool, optional): sample outputs instead of beam search
                (default: False)
            sampling_topk (int, optional): only sample among the top-k choices
                at each step (default: -1)
            sampling_temperature (float, optional): temperature for sampling,
                where values >1.0 produce more uniform sampling and values
                <1.0 produce sharper sampling (default: 1.0)
            diverse_beam_groups/strength (float, optional): parameters for
                Diverse Beam Search sampling
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
        """
        self.models = models
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.beam_size = beam_size
        self.minlen = minlen
        max_decoder_len = min(m.max_decoder_positions() for m in self.models)
        max_decoder_len -= 1  # we define maxlen not including the EOS marker
        self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.retain_dropout = retain_dropout
        self.match_source_len = match_source_len
        self.no_repeat_ngram_size = no_repeat_ngram_size

        assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'

        if sampling:
            self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
        elif diverse_beam_groups > 0:
            self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
        elif match_source_len:
            self.search = search.LengthConstrainedBeamSearch(
                tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
            )
        else:
            self.search = search.BeamSearch(tgt_dict)

    def cuda(self):
        for model in self.models:
            model.cuda()
        return self

    def generate_batched_itr(
        self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
        cuda=False, timer=None, prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b (int, optional): generate sequences of maximum length
                ``ax + b``, where ``x`` is the source sentence length.
            cuda (bool, optional): use GPU for generation
            timer (StopwatchMeter, optional): time generations
            prefix_size (int, optional): prefill the generation with the gold
                prefix up to this length.
        """
        if maxlen_b is None:
            maxlen_b = self.maxlen

        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if 'net_input' not in s:
                continue
            input = s['net_input']
            # model.forward normally channels prev_output_tokens into the decoder
            # separately, but SequenceGenerator directly calls model.encoder
            encoder_input = {
                k: v for k, v in input.items()
                if k != 'prev_output_tokens'
            }
            srclen = encoder_input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    encoder_input,
                    beam_size=beam_size,
                    maxlen=int(maxlen_a*srclen + maxlen_b),
                    prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
                )
            if timer is not None:
                timer.stop(sum(len(h[0]['tokens']) for h in hypos))
            for i, id in enumerate(s['id'].data):
                # remove padding
                src = utils.strip_pad(input['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                yield id, src, ref, hypos[i]

    def generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
        """Generate a batch of translations.

        Args:
            encoder_input (dict): dictionary containing the inputs to
                *model.encoder.forward*.
            beam_size (int, optional): overriding the beam size
                (default: *self.beam_size*).
            max_len (int, optional): maximum length of the generated sequence
            prefix_tokens (LongTensor, optional): force decoder to begin with
                these tokens
        """
        with torch.no_grad():
            return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)

    def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
        """See generate"""
        src_tokens = encoder_input['src_tokens']
        src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
        if self.match_source_len:
            maxlen = src_lengths.max().item()

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        incremental_states = {}
        for model in self.models:
            if not self.retain_dropout:
                model.eval()
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                incremental_states[model] = {}
            else:
                incremental_states[model] = None

            # compute the encoder output for each beam
            encoder_out = model.encoder(**encoder_input)
            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
            new_order = new_order.to(src_tokens.device).long()
            encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn, attn_buf = None, None
        nonpad_idxs = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen ** self.len_penalty
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                if self.match_source_len and step > src_lengths[unfin_idx]:
                    score = -math.inf

                def get_hypo():
                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                        _, alignment = hypo_attn.max(dim=0)
                    else:
                        hypo_attn = None
                        alignment = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            lprobs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                            gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
                    attn_buf = attn.clone()
                    nonpad_idxs = src_tokens.ne(self.pad)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
                self.search.set_src_lengths(src_lengths)

                if self.no_repeat_ngram_size > 0:
                    def calculate_banned_tokens(bbsz_idx):
                        # before decoding the next token, prevent decoding of ngrams that have already appeared
                        ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
                        return gen_ngrams[bbsz_idx].get(ngram_index, [])

                    if step + 2 - self.no_repeat_ngram_size >= 0:
                        # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
                        banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
                    else:
                        banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                    for bbsz_idx in range(bsz * beam_size):
                        lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = float('-Inf')

                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice, dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data
                    ).expand(-1, cand_size)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
                    cand_beams = torch.zeros_like(cand_indices)
                else:
                    cand_scores, cand_indices, cand_beams = self.search.step(
                        step,
                        lprobs.view(bsz, -1, self.vocab_size),
                        scores.view(bsz, beam_size, -1)[:, :, :step],
                    )
            else:
                # make probs contain cumulative scores for each hypothesis
                lprobs.add_(scores[:, step - 1].unsqueeze(-1))

                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    lprobs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= len(finalize_hypos(step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)

            finalized_sents = set()
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores, cand_scores)
                    num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(_ignore, active_hypos)
            )

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized

    def _decode(self, tokens, encoder_outs, incremental_states):
        if len(self.models) == 1:
            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)

        log_probs = []
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=True)
            log_probs.append(probs)
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(self.models))
        if avg_attn is not None:
            avg_attn.div_(len(self.models))
        return avg_probs, avg_attn

    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
        with torch.no_grad():
            if incremental_states[model] is not None:
                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model]))
            else:
                decoder_out = list(model.decoder(tokens, encoder_out))
            decoder_out[0] = decoder_out[0][:, -1, :]
            attn = decoder_out[1]
            if type(attn) is dict:
                attn = attn['attn']
            if attn is not None:
                if type(attn) is dict:
                    attn = attn['attn']
                attn = attn[:, -1, :]
        probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
        return probs, attn
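
For orientation, here is a minimal usage sketch (not part of the original file). It assumes `models` is a list of trained fairseq models sharing the target dictionary `tgt_dict`, and that `src_tokens` is an already-built padded LongTensor batch whose encoder accepts `src_tokens` and `src_lengths`, as in the translation task; how the models and dictionary are loaded is left out.

from fairseq.sequence_generator import SequenceGenerator

# `models`, `tgt_dict` and `src_tokens` are assumed to exist already (see note above).
generator = SequenceGenerator(models, tgt_dict, beam_size=5, maxlen=200)

# Build the encoder input the same way generate_batched_itr does: everything the
# encoder needs, without prev_output_tokens.
src_lengths = src_tokens.ne(tgt_dict.pad()).long().sum(dim=1)
encoder_input = {'src_tokens': src_tokens, 'src_lengths': src_lengths}

hypos = generator.generate(encoder_input)

# hypos[i] holds up to beam_size finalized hypotheses for sentence i, sorted by
# score; each entry is a dict with 'tokens', 'score', 'attention', 'alignment'
# and 'positional_scores'.
best = hypos[0][0]
print(best['score'], tgt_dict.string(best['tokens']))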
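
The no_repeat_ngram_size handling inside _generate is easy to miss among the beam bookkeeping, so here is a simplified, standalone sketch of the same idea (the helper name is hypothetical): collect every n-gram a hypothesis has already produced, then ban any token that would complete a repeat of one of those n-grams at the next step.

def banned_next_tokens(gen_tokens, no_repeat_ngram_size, step):
    """Return token ids that would repeat an n-gram already present in gen_tokens[:step + 1]."""
    gen_tokens = gen_tokens[:step + 1]
    # map each (n-1)-token prefix to the tokens that have followed it so far
    gen_ngrams = {}
    for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
        prefix = tuple(ngram[:-1])
        gen_ngrams[prefix] = gen_ngrams.get(prefix, []) + [ngram[-1]]
    if step + 2 - no_repeat_ngram_size < 0:
        return []  # not enough tokens generated yet to repeat an n-gram
    # the (n-1) most recent tokens form the prefix of the next candidate n-gram
    ngram_index = tuple(gen_tokens[step + 2 - no_repeat_ngram_size:step + 1])
    return gen_ngrams.get(ngram_index, [])

# e.g. with tokens [2, 5, 7, 5] at step 3 and no_repeat_ngram_size=2, the prefix (5,)
# has already been followed by 7, so 7 is banned for the next position.
print(banned_next_tokens([2, 5, 7, 5], 2, 3))  # -> [7]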