Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

sample.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
  1. import tensorflow as tf
  2. import model
  3. def top_k_logits(logits, k):
  4. if k == 0:
  5. # no truncation
  6. return logits
  7. def _top_k():
  8. values, _ = tf.nn.top_k(logits, k=k)
  9. min_values = values[:, -1, tf.newaxis]
  10. return tf.where(
  11. logits < min_values,
  12. tf.ones_like(logits, dtype=logits.dtype) * -1e10,
  13. logits,
  14. )
  15. return tf.cond(
  16. tf.equal(k, 0),
  17. lambda: logits,
  18. lambda: _top_k(),
  19. )
  20. def top_p_logits(logits, p):
  21. """Nucleus sampling"""
  22. batch, _ = logits.shape.as_list()
  23. sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
  24. cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
  25. indices = tf.stack([
  26. tf.range(0, batch),
  27. # number of indices to include
  28. tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
  29. ], axis=-1)
  30. min_values = tf.gather_nd(sorted_logits, indices)
  31. return tf.where(
  32. logits < min_values,
  33. tf.ones_like(logits) * -1e10,
  34. logits,
  35. )
  36. def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1):
  37. if start_token is None:
  38. assert context is not None, 'Specify exactly one of start_token and context!'
  39. else:
  40. assert context is None, 'Specify exactly one of start_token and context!'
  41. context = tf.fill([batch_size, 1], start_token)
  42. def step(hparams, tokens, past=None):
  43. lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
  44. logits = lm_output['logits'][:, :, :hparams.n_vocab]
  45. presents = lm_output['present']
  46. presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
  47. return {
  48. 'logits': logits,
  49. 'presents': presents,
  50. }
  51. with tf.name_scope('sample_sequence'):
  52. def body(past, prev, output):
  53. next_outputs = step(hparams, prev, past=past)
  54. logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
  55. logits = top_k_logits(logits, k=top_k)
  56. logits = top_p_logits(logits, p=top_p)
  57. samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
  58. return [
  59. next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
  60. samples,
  61. tf.concat([output, samples], axis=1)
  62. ]
  63. past, prev, output = body(None, context, context)
  64. def cond(*args):
  65. return True
  66. _, _, tokens = tf.while_loop(
  67. cond=cond, body=body,
  68. maximum_iterations=length - 1,
  69. loop_vars=[
  70. past,
  71. prev,
  72. output
  73. ],
  74. shape_invariants=[
  75. tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
  76. tf.TensorShape([batch_size, None]),
  77. tf.TensorShape([batch_size, None]),
  78. ],
  79. back_prop=False,
  80. )
  81. return tokens
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...