1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
|
- #!/usr/bin/env python3
- import fire
- import json
- import os
- import numpy as np
- import tensorflow as tf
- import model, sample, encoder
def sample_model(
    model_name='124M',
    seed=None,
    nsamples=0,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    top_p=1,
    models_dir='models',
):
    """
    Run the sample_model
    :model_name=124M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :nsamples=0 : Number of samples to return, if 0, continues to
     generate samples indefinitely.
    :batch_size=1 : Number of batches (only affects speed/memory).
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=1 : Float value controlling diversity (nucleus sampling). Restricts
     sampling to the smallest set of tokens whose cumulative probability
     exceeds top_p. 1 (default) disables the restriction.
    :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    # Layer the checkpoint's saved hyperparameters over the defaults.
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        np.random.seed(seed)
        tf.set_random_seed(seed)

        # Unconditional generation: seed every sample with the <|endoftext|>
        # token, then strip that start token ([:, 1:]) from the output.
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )[:, 1:]

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        generated = 0
        while nsamples == 0 or generated < nsamples:
            out = sess.run(output)
            for i in range(batch_size):
                # Count one sample per decoded sequence. (Previously this
                # added batch_size per sequence, which mislabeled samples and
                # stopped the loop after only nsamples / batch_size samples.)
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
# Script entry point: python-fire exposes sample_model's keyword arguments as
# command-line flags (e.g. --model_name=124M --top_k=40).
if __name__ == '__main__':
    fire.Fire(sample_model)
|