generate_unconditional_samples.py 2.8 KB

#!/usr/bin/env python3
import fire
import json
import os
import numpy as np
import tensorflow as tf

import model, sample, encoder


def sample_model(
    model_name='124M',
    seed=None,
    nsamples=0,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    top_p=1,
    models_dir='models',
):
    """
    Run the sample_model
    :model_name=124M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :nsamples=0 : Number of samples to return, if 0, continues to
     generate samples indefinitely.
    :batch_size=1 : Number of batches (only affects speed/memory).
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in Boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=1 : Float value controlling diversity via nucleus sampling:
     restricts sampling to the smallest set of tokens whose cumulative
     probability exceeds top_p. 1 (default) means no restriction.
    :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        np.random.seed(seed)
        tf.set_random_seed(seed)

        output = sample.sample_sequence(
            hparams=hparams, length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )[:, 1:]

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        generated = 0
        while nsamples == 0 or generated < nsamples:
            out = sess.run(output)
            for i in range(batch_size):
                generated += 1  # count one sample per decoded batch row
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)


if __name__ == '__main__':
    fire.Fire(sample_model)
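The temperature and top_k behavior described in the docstring can be illustrated with a small standalone sketch. Note this is only a conceptual illustration: the function name `top_k_temperature_probs` is ours and does not exist in this repository; the real filtering happens inside `sample.sample_sequence` on TensorFlow tensors.

```python
import numpy as np

def top_k_temperature_probs(logits, temperature=1.0, top_k=0):
    """Illustrative only: scale logits by temperature, keep the top_k
    candidates (0 = no restriction), and renormalize to probabilities."""
    logits = np.asarray(logits, dtype=np.float64) / temperature
    if top_k > 0:
        # Mask everything outside the k largest logits.
        kth_largest = np.sort(logits)[-top_k]
        logits = np.where(logits < kth_largest, -np.inf, logits)
    # Numerically stable softmax; masked entries get probability 0.
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()

# With top_k=2 the smallest logit is excluded entirely.
probs = top_k_temperature_probs([2.0, 1.0, 0.1], temperature=1.0, top_k=2)
# Lower temperature sharpens the distribution toward the argmax.
cold = top_k_temperature_probs([2.0, 1.0, 0.1], temperature=0.1)
```

As the docstring notes, `temperature` rescales the distribution (lower = more deterministic) while `top_k` hard-limits how many candidate tokens are eligible at each step; the two can be combined.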