interactive_conditional_samples.py

#!/usr/bin/env python3
import fire
import json
import os
import numpy as np
import tensorflow as tf

import model, sample, encoder

def interact_model(
    model_name='124M',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    top_p=1,
    models_dir='models',
):
    """
    Interactively run the model
    :model_name=124M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples.
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in Boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=1 : Float value controlling diversity via nucleus sampling. 1
     (default) means no restriction; 0.9 means only the smallest set of tokens
     whose cumulative probability reaches 0.9 is considered at each step.
    :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        while True:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(nsamples // batch_size):
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)]
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)

if __name__ == '__main__':
    fire.Fire(interact_model)
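
Because the entry point is wrapped in fire.Fire, every keyword argument of interact_model is exposed as a command-line flag. As a usage sketch, assuming the 124M weights have already been downloaded into models/124M, an invocation could look like:

    python3 interactive_conditional_samples.py --model_name=124M --top_k=40 --temperature=0.8

The script then reads prompts in a loop at "Model prompt >>> " and prints nsamples completions per prompt; interrupt it (e.g. with Ctrl-C) to exit.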
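
The temperature, top_k, and top_p knobs documented in the docstring are actually applied inside sample.sample_sequence (the repo's sample.py), which operates on TensorFlow tensors. The NumPy sketch below only illustrates what those three knobs do to a single step's logits; sample_next_token is a hypothetical helper written for this explanation, not part of the repository.

    import numpy as np

    def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0):
        """Draw one token id from a logits vector, applying the same three
        knobs the script exposes. Illustration only, not the repo's code."""
        # Temperature: rescale logits before the softmax; lower values sharpen
        # the distribution, higher values flatten it.
        logits = np.asarray(logits, dtype=np.float64) / temperature

        # Top-k: keep only the k highest-scoring tokens (0 = no restriction).
        if top_k > 0:
            cutoff = np.sort(logits)[-top_k]
            logits = np.where(logits < cutoff, -np.inf, logits)

        # Softmax over the surviving logits (exp(-inf) contributes 0).
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()

        # Top-p / nucleus: keep the smallest set of tokens whose cumulative
        # probability reaches p (1.0 = no restriction), then renormalize.
        if top_p < 1.0:
            order = np.argsort(probs)[::-1]
            cutoff_index = np.searchsorted(np.cumsum(probs[order]), top_p)
            keep = order[:cutoff_index + 1]
            mask = np.zeros_like(probs)
            mask[keep] = probs[keep]
            probs = mask / mask.sum()

        return np.random.choice(len(probs), p=probs)

For example, sample_next_token(logits, temperature=0.8, top_k=40) mirrors the top_k=40 setting the docstring recommends, while top_p=0.9 with top_k=0 would switch to pure nucleus sampling.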